fix migrate script and type mismatches in probe, config and loader

This commit is contained in:
Lincoln Stein
2023-09-29 00:09:07 -04:00
parent 81fce18c73
commit 2f16a2c35d
15 changed files with 96 additions and 84 deletions

View File

@ -25,6 +25,7 @@ from pydantic import BaseSettings
class PagingArgumentParser(argparse.ArgumentParser):
"""
A custom ArgumentParser that uses pydoc to page its output.
It also supports reading defaults from an init file.
"""
@ -226,9 +227,7 @@ class InvokeAISettings(BaseSettings):
def int_or_float_or_str(value: str) -> Union[int, float, str]:
"""
Workaround for argparse type checking.
"""
"""Workaround for argparse type checking."""
try:
return int(value)
except Exception as e: # noqa F841

View File

@ -257,7 +257,6 @@ class InvokeAIAppConfig(InvokeAISettings):
attention_type : Literal["auto", "normal", "xformers", "sliced", "torch-sdp"] = Field(default="auto", description="Attention type", category="Generation", )
attention_slice_size: Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', category="Generation", )
force_tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",)
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",)
# QUEUE
max_queue_size : int = Field(default=10000, gt=0, description="Maximum number of items in the session queue", category="Queue", )
@ -454,9 +453,7 @@ class InvokeAIAppConfig(InvokeAISettings):
def get_invokeai_config(**kwargs) -> InvokeAIAppConfig:
"""
Legacy function which returns InvokeAIAppConfig.get_config()
"""
"""Legacy function which returns InvokeAIAppConfig.get_config()."""
return InvokeAIAppConfig.get_config(**kwargs)

View File

@ -55,20 +55,10 @@ class CacheStats(object):
loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
class ModelLocker(object):
"Forward declaration"
pass
class ModelCache(object):
"Forward declaration"
pass
class _CacheRecord:
size: int
model: Any
cache: ModelCache
cache: "ModelCache"
_locks: int
def __init__(self, cache, model: Any, size: int):
@ -130,7 +120,7 @@ class ModelCache(object):
self.logger = logger
# used for stats collection
self.stats = None
self.stats: Optional[CacheStats] = None
self._cached_models = dict()
self._cache_stack = list()
@ -160,7 +150,7 @@ class ModelCache(object):
if model_info_key not in self.model_infos:
self.model_infos[model_info_key] = model_class(
model_path,
model_path.as_posix(),
base_model,
model_type,
)

View File

@ -118,7 +118,7 @@ class ModelConfigBase(BaseModel):
base_model: BaseModelType
model_type: ModelType
model_format: ModelFormat
key: Optional[str] = Field(None) # this will get added by the store
key: str = Field(description="hash key for model", default="<NOKEY>") # this will get added by the store
description: Optional[str] = Field(None)
author: Optional[str] = Field(description="Model author")
license: Optional[str] = Field(description="License string")
@ -233,6 +233,16 @@ class IPAdapterConfig(ModelConfigBase):
model_format: Literal[ModelFormat.InvokeAI]
AnyModelConfig = Union[
MainCheckpointConfig,
MainDiffusersConfig,
LoRAConfig,
TextualInversionConfig,
ONNXSD1Config,
ONNXSD2Config,
]
class ModelConfigFactory(object):
"""Class for parsing config dicts into StableDiffusion Config obects."""
@ -279,14 +289,7 @@ class ModelConfigFactory(object):
model_data: Union[dict, ModelConfigBase],
key: Optional[str] = None,
dest_class: Optional[Type] = None,
) -> Union[
MainCheckpointConfig,
MainDiffusersConfig,
LoRAConfig,
TextualInversionConfig,
ONNXSD1Config,
ONNXSD2Config,
]:
) -> AnyModelConfig:
"""
Return the appropriate config object from raw dict values.

View File

@ -60,7 +60,7 @@ from pydantic import Field
from pydantic.networks import AnyHttpUrl
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util import Chdir, InvokeAILogger
from invokeai.backend.util import Chdir, InvokeAILogger, Logger
from .config import (
BaseModelType,
@ -321,7 +321,7 @@ class ModelInstall(ModelInstallBase):
"""Model installer class handles installation from a local path."""
_app_config: InvokeAIAppConfig
_logger: InvokeAILogger
_logger: Logger
_store: ModelConfigStore
_download_queue: DownloadQueueBase
_async_installs: Dict[str, str]
@ -355,7 +355,7 @@ class ModelInstall(ModelInstallBase):
self,
store: Optional[ModelConfigStore] = None,
config: Optional[InvokeAIAppConfig] = None,
logger: Optional[InvokeAILogger] = None,
logger: Optional[Logger] = None,
download: Optional[DownloadQueueBase] = None,
event_handlers: Optional[List[DownloadEventHandler]] = None,
): # noqa D107 - use base class docstrings

View File

@ -5,16 +5,16 @@ import hashlib
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Union
from typing import List, Optional, Tuple, Union
import torch
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util import InvokeAILogger, choose_precision, choose_torch_device
from invokeai.backend.util import InvokeAILogger, Logger, choose_precision, choose_torch_device
from .cache import CacheStats, ModelCache, ModelLocker
from .cache import CacheStats, ModelCache
from .config import BaseModelType, ModelConfigBase, ModelType, SubModelType
from .download import DownloadEventHandler
from .download import DownloadEventHandler, DownloadQueueBase
from .install import ModelInstall, ModelInstallBase
from .models import MODEL_CLASSES, InvalidModelException, ModelBase
from .storage import ConfigFileVersionMismatchException, ModelConfigStore, get_config_store, migrate_models_store
@ -24,10 +24,10 @@ from .storage import ConfigFileVersionMismatchException, ModelConfigStore, get_c
class ModelInfo:
"""This is a context manager object that is used to intermediate access to a model."""
context: ModelLocker
context: ModelCache.ModelLocker
name: str
base_model: BaseModelType
type: ModelType
type: Union[ModelType, SubModelType]
key: str
location: Union[Path, str]
precision: torch.dtype
@ -73,7 +73,7 @@ class ModelLoadBase(ABC):
@property
@abstractmethod
def logger(self) -> InvokeAILogger:
def logger(self) -> Logger:
"""Return the current logger."""
pass
@ -84,7 +84,7 @@ class ModelLoadBase(ABC):
@property
@abstractmethod
def queue(self) -> str:
def queue(self) -> DownloadQueueBase:
"""Return the download queue object used by this object."""
pass
@ -106,10 +106,7 @@ class ModelLoadBase(ABC):
@abstractmethod
def sync_to_config(self):
"""
Reinitialize the store to sync in-memory and in-disk
versions.
"""
"""Reinitialize the store to sync in-memory and in-disk versions."""
pass
@ -120,7 +117,7 @@ class ModelLoad(ModelLoadBase):
_store: ModelConfigStore
_installer: ModelInstallBase
_cache: ModelCache
_logger: InvokeAILogger
_logger: Logger
_cache_keys: dict
_models_file: Path
@ -195,7 +192,7 @@ class ModelLoad(ModelLoadBase):
return self._installer
@property
def logger(self) -> InvokeAILogger:
def logger(self) -> Logger:
"""Return the current logger."""
return self._logger
@ -205,7 +202,7 @@ class ModelLoad(ModelLoadBase):
return self._app_config
@property
def queue(self) -> str:
def queue(self) -> DownloadQueueBase:
"""Return the download queue object used by this object."""
return self._installer.queue
@ -267,6 +264,7 @@ class ModelLoad(ModelLoadBase):
)
def collect_cache_stats(self, cache_stats: CacheStats):
"""Save CacheStats object for stats collecting."""
self._cache.stats = cache_stats
def resolve_model_path(self, path: Union[Path, str]) -> Path:
@ -283,12 +281,12 @@ class ModelLoad(ModelLoadBase):
def _get_model_path(
self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None
) -> (Path, bool):
) -> Tuple[Path, bool]:
"""Extract a model's filesystem path from its config.
:return: The fully qualified Path of the module (or submodule).
"""
model_path = model_config.path
model_path = Path(model_config.path)
is_submodel_override = False
# Does the config explicitly override the submodel?
@ -302,5 +300,6 @@ class ModelLoad(ModelLoadBase):
return model_path, is_submodel_override
def sync_to_config(self):
"""Synchronize models on disk to those in memory."""
self._store = get_config_store(self._models_file)
self.installer.scan_models_directory()

View File

@ -177,6 +177,15 @@ class ModelBase(metaclass=ABCMeta):
) -> Any:
raise NotImplementedError()
@classmethod
@abstractmethod
def convert_if_required(
cls,
model_config: ModelConfigBase,
output_path: str,
) -> str:
raise NotImplementedError()
class DiffusersModel(ModelBase):
# child_types: Dict[str, Type]

View File

@ -10,7 +10,7 @@ import json
import re
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Callable, Optional
from typing import Callable, Dict, Optional
import safetensors.torch
import torch
@ -33,8 +33,8 @@ class ModelProbeInfo(BaseModel):
base_type: BaseModelType
format: ModelFormat
hash: str
variant_type: Optional[ModelVariantType] = "normal"
prediction_type: Optional[SchedulerPredictionType] = "v_prediction"
variant_type: Optional[ModelVariantType] = ModelVariantType("normal")
prediction_type: Optional[SchedulerPredictionType] = SchedulerPredictionType("v_prediction")
upcast_attention: Optional[bool] = False
image_size: Optional[int] = None
@ -63,7 +63,7 @@ class ProbeBase(ABC):
"""Base model for probing checkpoint and diffusers-style models."""
@abstractmethod
def get_base_type(self) -> BaseModelType:
def get_base_type(self) -> Optional[BaseModelType]:
"""Return the BaseModelType for the model."""
pass
@ -71,7 +71,7 @@ class ProbeBase(ABC):
"""Return the ModelVariantType for the model."""
pass
def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
def get_scheduler_prediction_type(self) -> Optional[SchedulerPredictionType]:
"""Return the SchedulerPredictionType for the model."""
pass
@ -83,7 +83,7 @@ class ProbeBase(ABC):
class ModelProbe(ModelProbeBase):
"""Class to probe a checkpoint, safetensors or diffusers folder."""
PROBES = {
PROBES: Dict[str, dict] = {
"diffusers": {},
"checkpoint": {},
"onnx": {},
@ -252,7 +252,7 @@ class ModelProbe(ModelProbeBase):
# scan model
scan_result = scan_file_path(model)
if scan_result.infected_files != 0:
raise "The model {model_name} is potentially infected by malware. Aborting import."
raise InvalidModelException("The model {model_name} is potentially infected by malware. Aborting import.")
# ##################################################3
@ -263,15 +263,13 @@ class ModelProbe(ModelProbeBase):
class CheckpointProbeBase(ProbeBase):
"""Base class for probing checkpoint-style models."""
def __init__(
self, model: Path, helper: Optional[Callable[[Path], SchedulerPredictionType]] = None
) -> BaseModelType:
def __init__(self, checkpoint_path: Path, helper: Optional[Callable[[Path], SchedulerPredictionType]] = None):
"""Initialize the CheckpointProbeBase object."""
self.checkpoint = ModelProbe._scan_and_load_checkpoint(model)
self.model = model
self.checkpoint_path = checkpoint_path
self.checkpoint = ModelProbe._scan_and_load_checkpoint(checkpoint_path)
self.helper = helper
def get_base_type(self) -> BaseModelType:
def get_base_type(self) -> Optional[BaseModelType]:
"""Return the BaseModelType of a checkpoint-style model."""
pass
@ -281,7 +279,7 @@ class CheckpointProbeBase(ProbeBase):
def get_variant_type(self) -> ModelVariantType:
"""Return the ModelVariantType of a checkpoint-style model."""
model_type = ModelProbe.get_model_type_from_checkpoint(self.model)
model_type = ModelProbe.get_model_type_from_checkpoint(self.checkpoint_path)
if model_type != ModelType.Main:
return ModelVariantType.Normal
state_dict = self.checkpoint.get("state_dict") or self.checkpoint
@ -378,13 +376,13 @@ class LoRACheckpointProbe(CheckpointProbeBase):
elif token_vector_length == 2048:
return BaseModelType.StableDiffusionXL
else:
raise InvalidModelException(f"Unsupported LoRA type: {self.model}")
raise InvalidModelException(f"Unsupported LoRA type: {self.checkpoint_path}")
class TextualInversionCheckpointProbe(CheckpointProbeBase):
"""TextualInversion checkpoint prober."""
def get_format(self) -> Optional[str]:
def get_format(self) -> str:
"""Return the format of a TextualInversion emedding."""
return ModelFormat.EmbeddingFile
@ -401,8 +399,7 @@ class TextualInversionCheckpointProbe(CheckpointProbeBase):
return BaseModelType.StableDiffusion1
elif token_dim == 1024:
return BaseModelType.StableDiffusion2
else:
return None
raise InvalidModelException("Unknown base model for {self.checkpoint_path}")
class ControlNetCheckpointProbe(CheckpointProbeBase):
@ -421,18 +418,22 @@ class ControlNetCheckpointProbe(CheckpointProbeBase):
return BaseModelType.StableDiffusion1
elif checkpoint[key_name].shape[-1] == 1024:
return BaseModelType.StableDiffusion2
elif self.checkpoint_path and self.helper:
return self.helper(self.checkpoint_path)
raise InvalidModelException("Unable to determine base type for {self.checkpoint_path}")
class IPAdapterCheckpointProbe(CheckpointProbeBase):
"""Probe IP adapter models."""
def get_base_type(self) -> BaseModelType:
"""Probe base type."""
raise NotImplementedError()
class CLIPVisionCheckpointProbe(CheckpointProbeBase):
"""Probe ClipVision adapter models."""
def get_base_type(self) -> BaseModelType:
"""Probe base type."""
raise NotImplementedError()
@ -511,6 +512,7 @@ class VaeFolderProbe(FolderProbeBase):
"""Class for probing folder-style models."""
def get_base_type(self) -> BaseModelType:
"""Get base type of model."""
if self._config_looks_like_sdxl():
return BaseModelType.StableDiffusionXL
elif self._name_looks_like_sdxl():
@ -542,7 +544,7 @@ class VaeFolderProbe(FolderProbeBase):
class TextualInversionFolderProbe(FolderProbeBase):
"""Probe a HuggingFace-style TextualInversion folder."""
def get_format(self) -> Optional[str]:
def get_format(self) -> str:
"""Return the format of the TextualInversion."""
return ModelFormat.EmbeddingFolder
@ -616,9 +618,11 @@ class IPAdapterFolderProbe(FolderProbeBase):
"""Class for probing IP-Adapter models."""
def get_format(self) -> str:
"""Get format of ip adapter."""
return ModelFormat.InvokeAI.value
def get_base_type(self) -> BaseModelType:
"""Get base type of ip adapter."""
model_file = self.folder_path / "ip_adapter.bin"
if not model_file.exists():
raise InvalidModelException("Unknown IP-Adapter model format.")
@ -636,7 +640,10 @@ class IPAdapterFolderProbe(FolderProbeBase):
class CLIPVisionFolderProbe(FolderProbeBase):
"""Probe for folder-based CLIPVision models."""
def get_base_type(self) -> BaseModelType:
"""Get base type."""
return BaseModelType.Any

View File

@ -7,7 +7,7 @@ from abc import ABC, abstractmethod
from pathlib import Path
from typing import List, Optional, Set, Union
from ..config import BaseModelType, ModelConfigBase, ModelType
from ..config import AnyModelConfig, BaseModelType, ModelConfigBase, ModelType
# should match the InvokeAI version when this is first released.
CONFIG_FILE_VERSION = "3.2"
@ -76,9 +76,9 @@ class ModelConfigStore(ABC):
pass
@abstractmethod
def get_model(self, key: str) -> ModelConfigBase:
def get_model(self, key: str) -> AnyModelConfig:
"""
Retrieve the ModelConfigBase instance for the indicated model.
Retrieve the configuration for the indicated model.
:param key: Key of model config to be fetched.

View File

@ -8,12 +8,14 @@ from omegaconf import OmegaConf
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger
from ..config import BaseModelType, MainCheckpointConfig, MainConfig, ModelType
from .base import CONFIG_FILE_VERSION
def migrate_models_store(config: InvokeAIAppConfig):
"""Migrate models from v1 models.yaml to v3.2 models.yaml."""
# avoid circular import
from invokeai.backend.model_manager import DuplicateModelException, InvalidModelException, ModelInstall
from invokeai.backend.model_manager import DuplicateModelException, ModelInstall
from invokeai.backend.model_manager.storage import get_config_store
app_config = InvokeAIAppConfig.get_config()
@ -33,14 +35,17 @@ def migrate_models_store(config: InvokeAIAppConfig):
), f"This script works on version 3.0.0 yaml files, but your configuration points to a {stanza['version']} version"
continue
base_type, model_type, model_name = model_key.split("/")
base_type, model_type, model_name = str(model_key).split("/")
new_key = "<NOKEY>"
try:
path = app_config.models_path / stanza["path"]
new_key = installer.register_path(path)
except DuplicateModelException:
# if model already installed, then we just update its info
models = store.search_by_name(model_name=model_name, base_model=base_type, model_type=model_type)
models = store.search_by_name(
model_name=model_name, base_model=BaseModelType(base_type), model_type=ModelType(model_type)
)
if len(models) != 1:
continue
new_key = models[0].key
@ -48,9 +53,9 @@ def migrate_models_store(config: InvokeAIAppConfig):
print(str(excp))
model_info = store.get_model(new_key)
if vae := stanza.get("vae"):
if vae := stanza.get("vae") and isinstance(model_info, MainConfig):
model_info.vae = (app_config.models_path / vae).as_posix()
if model_config := stanza.get("config"):
if model_config := stanza.get("config") and isinstance(model_info, MainCheckpointConfig):
model_info.config = (app_config.root_path / model_config).as_posix()
model_info.description = stanza.get("description")
store.update_model(new_key, model_info)

View File

@ -46,7 +46,7 @@ import threading
from pathlib import Path
from typing import List, Optional, Set, Union
from ..config import BaseModelType, ModelConfigBase, ModelConfigFactory, ModelType
from ..config import AnyModelConfig, BaseModelType, ModelConfigBase, ModelConfigFactory, ModelType
from .base import CONFIG_FILE_VERSION, DuplicateModelException, ModelConfigStore, UnknownModelException
@ -350,7 +350,7 @@ class ModelConfigStoreSQL(ModelConfigStore):
self._lock.release()
return self.get_model(key)
def get_model(self, key: str) -> ModelConfigBase:
def get_model(self, key: str) -> AnyModelConfig:
"""
Retrieve the ModelConfigBase instance for the indicated model.

View File

@ -49,7 +49,7 @@ import yaml
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from ..config import BaseModelType, ModelConfigBase, ModelConfigFactory, ModelType
from ..config import AnyModelConfig, BaseModelType, ModelConfigBase, ModelConfigFactory, ModelType
from .base import (
CONFIG_FILE_VERSION,
ConfigFileVersionMismatchException,
@ -170,7 +170,7 @@ class ModelConfigStoreYAML(ModelConfigStore):
self._lock.release()
return self.get_model(key)
def get_model(self, key: str) -> ModelConfigBase:
def get_model(self, key: str) -> AnyModelConfig:
"""
Retrieve the ModelConfigBase instance for the indicated model.

View File

@ -1,6 +1,8 @@
"""
Initialization file for invokeai.backend.util
"""
from logging import Logger # noqa: F401
from .attention import auto_detect_slice_size # noqa: F401
from .devices import ( # noqa: F401
CPU_DEVICE,

View File

@ -180,7 +180,7 @@ import socket
import urllib.parse
from abc import abstractmethod
from pathlib import Path
from typing import Optional
from typing import Dict
from invokeai.app.services.config import InvokeAIAppConfig
@ -345,7 +345,7 @@ LOG_FORMATTERS = {
class InvokeAILogger(object):
loggers = dict()
loggers: Dict[str, logging.Logger] = dict()
@classmethod
def get_logger(

View File

@ -12,6 +12,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
populate_graph,
prepare_values_to_insert,
)
from .test_nodes import PromptTestInvocation