Commit 67eb715093 (parent 8ba5360269): loaders for main, controlnet, ip-adapter, clipvision and t2i
@@ -173,7 +173,7 @@ from __future__ import annotations
import os
from pathlib import Path
from typing import Any, ClassVar, Dict, List, Literal, Optional, Union
from typing import Any, ClassVar, Dict, List, Literal, Optional

from omegaconf import DictConfig, OmegaConf
from pydantic import Field
@@ -11,7 +11,14 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union
from pydantic import BaseModel, Field

from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.backend.model_manager import LoadedModel, AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.model_manager import (
    AnyModelConfig,
    BaseModelType,
    LoadedModel,
    ModelFormat,
    ModelType,
    SubModelType,
)
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore

@@ -108,7 +115,7 @@ class ModelRecordServiceBase(ABC):
        Load the indicated model into memory and return a LoadedModel object.

        :param key: Key of model config to be fetched.
        :param submodel_type: For main (pipeline models), the submodel to fetch
        :param submodel_type: For main (pipeline models), the submodel to fetch

        Exceptions: UnknownModelException -- model with this key not known
                    NotImplementedException -- a model loader was not provided at initialization time
@@ -42,7 +42,6 @@ Typical usage:
import json
import sqlite3
import time
from math import ceil
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, Union

@@ -56,8 +55,8 @@ from invokeai.backend.model_manager.config import (
    ModelType,
    SubModelType,
)
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore, UnknownMetadataException
from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore, UnknownMetadataException

from ..shared.sqlite.sqlite_database import SqliteDatabase
from .model_records_base import (

@@ -72,7 +71,7 @@ from .model_records_base import (
class ModelRecordServiceSQL(ModelRecordServiceBase):
    """Implementation of the ModelConfigStore ABC using a SQL database."""

    def __init__(self, db: SqliteDatabase, loader: Optional[AnyModelLoader]=None):
    def __init__(self, db: SqliteDatabase, loader: Optional[AnyModelLoader] = None):
        """
        Initialize a new object from preexisting sqlite3 connection and threading lock objects.

@@ -289,7 +288,9 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
            """,
            tuple(bindings),
        )
        results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()]
        results = [
            ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()
        ]
        return results

    def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:

@@ -303,7 +304,9 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
            """,
            (str(path),),
        )
        results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()]
        results = [
            ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()
        ]
        return results

    def search_by_hash(self, hash: str) -> List[AnyModelConfig]:

@@ -317,7 +320,9 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
            """,
            (hash,),
        )
        results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()]
        results = [
            ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()
        ]
        return results

    @property
@@ -1,11 +1,9 @@
import sqlite3
from logging import Logger

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration

class Migration6Callback:

class Migration6Callback:
    def __call__(self, cursor: sqlite3.Cursor) -> None:
        self._recreate_model_triggers(cursor)

@@ -28,6 +26,7 @@ class Migration6Callback:
            """
        )


def build_migration_6() -> Migration:
    """
    Build the migration from database version 5 to 6.
@@ -120,6 +120,7 @@ class TqdmEventService(EventServiceBase):
        elif payload["event"] == "model_install_cancelled":
            self._logger.warning(f"{source}: installation cancelled")


class InstallHelper(object):
    """Capture information stored jointly in INITIAL_MODELS.yaml and the installed models db."""
@@ -139,7 +139,6 @@ def _convert_controlnet_ckpt_and_cache(
    cache it to disk, and return Path to converted
    file. If already on disk then just returns Path.
    """
    print(f"DEBUG: controlnet config = {model_config}")
    app_config = InvokeAIAppConfig.get_config()
    weights = app_config.root_path / model_path
    output_path = Path(output_path)
@@ -13,9 +13,9 @@ from .config import (
    SchedulerPredictionType,
    SubModelType,
)
from .load import LoadedModel
from .probe import ModelProbe
from .search import ModelSearch
from .load import LoadedModel

__all__ = [
    "AnyModel",
@@ -20,14 +20,16 @@ Validation errors will raise an InvalidModelConfigException error.
"""
import time
import torch
from enum import Enum
from typing import Literal, Optional, Type, Union

from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
import torch
from diffusers import ModelMixin
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
from typing_extensions import Annotated, Any, Dict

from .onnx_runtime import IAIOnnxRuntimeModel
from ..ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus

class InvalidModelConfigException(Exception):
    """Exception for when config parser doesn't recognized this combination of model type and format."""

@@ -204,6 +206,8 @@ class _MainConfig(ModelConfigBase):

    vae: Optional[str] = Field(default=None)
    variant: ModelVariantType = ModelVariantType.Normal
    prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
    upcast_attention: bool = False
    ztsnr_training: bool = False

@@ -217,8 +221,6 @@ class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
    """Model config for main diffusers models."""

    type: Literal[ModelType.Main] = ModelType.Main
    prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
    upcast_attention: bool = False


class ONNXSD1Config(_MainConfig):

@@ -276,6 +278,7 @@ AnyModelConfig = Union[
    _ONNXConfig,
    _VaeConfig,
    _ControlNetConfig,
    # ModelConfigBase,
    LoRAConfig,
    TextualInversionConfig,
    IPAdapterConfig,

@@ -284,7 +287,7 @@ AnyModelConfig = Union[
]

AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
AnyModel = Union[ModelMixin, torch.nn.Module, IAIOnnxRuntimeModel]
AnyModel = Union[ModelMixin, torch.nn.Module, IAIOnnxRuntimeModel, IPAdapter, IPAdapterPlus]

# IMPLEMENTATION NOTE:
# The preferred alternative to the above is a discriminated Union as shown
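
The implementation note above refers to pydantic's discriminated-union pattern. For readers who have not seen it, here is a minimal, self-contained sketch of that pattern; the class and field names are illustrative placeholders, not the InvokeAI config classes.

from typing import Literal, Union

from pydantic import BaseModel, Field, TypeAdapter
from typing_extensions import Annotated


class _LoraExample(BaseModel):
    type: Literal["lora"] = "lora"
    path: str


class _VaeExample(BaseModel):
    type: Literal["vae"] = "vae"
    path: str


# The "type" field acts as the discriminator, so validation dispatches
# directly to the matching class instead of trying each union member in turn.
ExampleAnyConfig = Annotated[Union[_LoraExample, _VaeExample], Field(discriminator="type")]
ExampleAnyConfigValidator = TypeAdapter(ExampleAnyConfig)

config = ExampleAnyConfigValidator.validate_python({"type": "vae", "path": "/tmp/vae.safetensors"})
assert isinstance(config, _VaeExample)
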
@@ -317,7 +320,7 @@ class ModelConfigFactory(object):
        model_data: Union[dict, AnyModelConfig],
        key: Optional[str] = None,
        dest_class: Optional[Type] = None,
        timestamp: Optional[float] = None
        timestamp: Optional[float] = None,
    ) -> AnyModelConfig:
        """
        Return the appropriate config object from raw dict values.
@@ -43,7 +43,6 @@ from diffusers.schedulers import (
    UnCLIPScheduler,
)
from diffusers.utils import is_accelerate_available
from diffusers.utils.import_utils import BACKENDS_MAPPING
from picklescan.scanner import scan_file_path
from transformers import (
    AutoFeatureExtractor,

@@ -58,8 +57,8 @@ from transformers import (
)

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.backend.model_manager import BaseModelType, ModelVariantType
from invokeai.backend.util.logging import InvokeAILogger

try:
    from omegaconf import OmegaConf

@@ -1643,7 +1642,6 @@ def download_controlnet_from_original_ckpt(
    cross_attention_dim: Optional[bool] = None,
    scan_needed: bool = False,
) -> DiffusionPipeline:
    from omegaconf import OmegaConf

    if from_safetensors:
@@ -8,14 +8,15 @@ from typing import Optional

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger

from .convert_cache.convert_cache_default import ModelConvertCache
from .load_base import AnyModelLoader, LoadedModel
from .model_cache.model_cache_default import ModelCache
from .convert_cache.convert_cache_default import ModelConvertCache

# This registers the subclasses that implement loaders of specific model types
loaders = [x.stem for x in Path(Path(__file__).parent,'model_loaders').glob('*.py') if x.stem != '__init__']
loaders = [x.stem for x in Path(Path(__file__).parent, "model_loaders").glob("*.py") if x.stem != "__init__"]
for module in loaders:
    print(f'module={module}')
    print(f"module={module}")
    import_module(f"{__package__}.model_loaders.{module}")

__all__ = ["AnyModelLoader", "LoadedModel"]

@@ -24,12 +25,11 @@ __all__ = ["AnyModelLoader", "LoadedModel"]
def get_standalone_loader(app_config: Optional[InvokeAIAppConfig]) -> AnyModelLoader:
    app_config = app_config or InvokeAIAppConfig.get_config()
    logger = InvokeAILogger.get_logger(config=app_config)
    return AnyModelLoader(app_config=app_config,
                          logger=logger,
                          ram_cache=ModelCache(logger=logger,
                                               max_cache_size=app_config.ram_cache_size,
                                               max_vram_cache_size=app_config.vram_cache_size
                                               ),
                          convert_cache=ModelConvertCache(app_config.models_convert_cache_path)
                          )
    return AnyModelLoader(
        app_config=app_config,
        logger=logger,
        ram_cache=ModelCache(
            logger=logger, max_cache_size=app_config.ram_cache_size, max_vram_cache_size=app_config.vram_cache_size
        ),
        convert_cache=ModelConvertCache(app_config.models_convert_cache_path),
    )
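
The glob over model_loaders/*.py above exists only to import each loader module so that its @AnyModelLoader.register(...) decorator runs at startup. A minimal standalone sketch of that decorator-registry idea, using generic placeholder names rather than the InvokeAI classes:

from typing import Callable, Dict, Type


class Registry:
    """Tiny stand-in for AnyModelLoader's class registry."""

    _registry: Dict[str, Type] = {}

    @classmethod
    def register(cls, key: str) -> Callable[[Type], Type]:
        def decorator(subclass: Type) -> Type:
            cls._registry[key] = subclass
            return subclass

        return decorator

    @classmethod
    def get(cls, key: str) -> Type:
        try:
            return cls._registry[key]
        except KeyError:
            raise NotImplementedError(f"No implementation registered for {key}") from None


@Registry.register("any-controlnet-diffusers")
class _ExampleLoader:
    pass


# Importing the module that defines _ExampleLoader is what populates the
# registry, hence the glob over model_loaders/*.py in the __init__ above.
assert Registry.get("any-controlnet-diffusers") is _ExampleLoader
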
@@ -1,4 +1,4 @@
from .convert_cache_base import ModelConvertCacheBase
from .convert_cache_default import ModelConvertCache

__all__ = ['ModelConvertCacheBase', 'ModelConvertCache']
__all__ = ["ModelConvertCacheBase", "ModelConvertCache"]
@@ -4,8 +4,8 @@ Disk-based converted model cache.
from abc import ABC, abstractmethod
from pathlib import Path

class ModelConvertCacheBase(ABC):

class ModelConvertCacheBase(ABC):
    @property
    @abstractmethod
    def max_size(self) -> float:

@@ -25,4 +25,3 @@ class ModelConvertCacheBase(ABC):
    def cache_path(self, key: str) -> Path:
        """Return the path for a model with the indicated key."""
        pass
@@ -2,15 +2,17 @@
Placeholder for convert cache implementation.
"""

from pathlib import Path
import shutil
from invokeai.backend.util.logging import InvokeAILogger
from pathlib import Path

from invokeai.backend.util import GIG, directory_size
from invokeai.backend.util.logging import InvokeAILogger

from .convert_cache_base import ModelConvertCacheBase

class ModelConvertCache(ModelConvertCacheBase):

    def __init__(self, cache_path: Path, max_size: float=10.0):
class ModelConvertCache(ModelConvertCacheBase):
    def __init__(self, cache_path: Path, max_size: float = 10.0):
        """Initialize the convert cache with the base directory and a limit on its maximum size (in GBs)."""
        if not cache_path.exists():
            cache_path.mkdir(parents=True)
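
The new GIG and directory_size imports suggest the convert cache enforces max_size by measuring its directory on disk. A rough sketch of that kind of size-capped pruning, with helper names chosen for illustration only (this is not the actual InvokeAI implementation):

import shutil
from pathlib import Path

GIG = 2**30  # bytes per gigabyte, assumed to match invokeai.backend.util.GIG


def directory_size(path: Path) -> int:
    """Total size in bytes of all files below path."""
    return sum(p.stat().st_size for p in path.rglob("*") if p.is_file())


def make_room(cache_dir: Path, size_needed: int, max_size_gb: float) -> None:
    """Delete least-recently-used converted models until size_needed fits under the cap."""
    limit = int(max_size_gb * GIG)
    entries = sorted((p for p in cache_dir.iterdir() if p.is_dir()), key=lambda p: p.stat().st_atime)
    for entry in entries:
        if directory_size(cache_dir) + size_needed <= limit:
            break
        shutil.rmtree(entry)
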
@ -10,17 +10,19 @@ Use like this:
|
||||
# do something with loaded_model
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Optional, Type, Union
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, Type
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
||||
from invokeai.backend.model_manager.load.model_cache.model_locker import ModelLockerBase
|
||||
from invokeai.backend.model_manager.config import VaeCheckpointConfig, VaeDiffusersConfig
|
||||
from invokeai.backend.model_manager.load.convert_cache.convert_cache_base import ModelConvertCacheBase
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoadedModel:
|
||||
@ -52,7 +54,7 @@ class ModelLoaderBase(ABC):
|
||||
self,
|
||||
app_config: InvokeAIAppConfig,
|
||||
logger: Logger,
|
||||
ram_cache: ModelCacheBase,
|
||||
ram_cache: ModelCacheBase[AnyModel],
|
||||
convert_cache: ModelConvertCacheBase,
|
||||
):
|
||||
"""Initialize the loader."""
|
||||
@ -91,7 +93,7 @@ class AnyModelLoader:
|
||||
self,
|
||||
app_config: InvokeAIAppConfig,
|
||||
logger: Logger,
|
||||
ram_cache: ModelCacheBase,
|
||||
ram_cache: ModelCacheBase[AnyModel],
|
||||
convert_cache: ModelConvertCacheBase,
|
||||
):
|
||||
"""Initialize AnyModelLoader with its dependencies."""
|
||||
@ -101,11 +103,11 @@ class AnyModelLoader:
|
||||
self._convert_cache = convert_cache
|
||||
|
||||
@property
|
||||
def ram_cache(self) -> ModelCacheBase:
|
||||
def ram_cache(self) -> ModelCacheBase[AnyModel]:
|
||||
"""Return the RAM cache associated used by the loaders."""
|
||||
return self._ram_cache
|
||||
|
||||
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType]=None) -> LoadedModel:
|
||||
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||
"""
|
||||
Return a model given its configuration.
|
||||
|
||||
@ -113,9 +115,7 @@ class AnyModelLoader:
|
||||
:param submodel_type: an ModelType enum indicating the portion of
|
||||
the model to retrieve (e.g. ModelType.Vae)
|
||||
"""
|
||||
implementation = self.__class__.get_implementation(
|
||||
base=model_config.base, type=model_config.type, format=model_config.format
|
||||
)
|
||||
implementation, model_config, submodel_type = self.__class__.get_implementation(model_config, submodel_type)
|
||||
return implementation(
|
||||
app_config=self._app_config,
|
||||
logger=self._logger,
|
||||
@ -128,16 +128,37 @@ class AnyModelLoader:
|
||||
return "-".join([base.value, type.value, format.value])
|
||||
|
||||
@classmethod
|
||||
def get_implementation(cls, base: BaseModelType, type: ModelType, format: ModelFormat) -> Type[ModelLoaderBase]:
|
||||
def get_implementation(
|
||||
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
|
||||
) -> Tuple[Type[ModelLoaderBase], AnyModelConfig, Optional[SubModelType]]:
|
||||
"""Get subclass of ModelLoaderBase registered to handle base and type."""
|
||||
key1 = cls._to_registry_key(base, type, format) # for a specific base type
|
||||
key2 = cls._to_registry_key(BaseModelType.Any, type, format) # with wildcard Any
|
||||
# We have to handle VAE overrides here because this will change the model type and the corresponding implementation returned
|
||||
conf2, submodel_type = cls._handle_subtype_overrides(config, submodel_type)
|
||||
|
||||
key1 = cls._to_registry_key(conf2.base, conf2.type, conf2.format) # for a specific base type
|
||||
key2 = cls._to_registry_key(BaseModelType.Any, conf2.type, conf2.format) # with wildcard Any
|
||||
implementation = cls._registry.get(key1) or cls._registry.get(key2)
|
||||
if not implementation:
|
||||
raise NotImplementedError(
|
||||
f"No subclass of LoadedModel is registered for base={base}, type={type}, format={format}"
|
||||
f"No subclass of LoadedModel is registered for base={config.base}, type={config.type}, format={config.format}"
|
||||
)
|
||||
return implementation
|
||||
return implementation, conf2, submodel_type
|
||||
|
||||
@classmethod
|
||||
def _handle_subtype_overrides(
|
||||
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
|
||||
) -> Tuple[AnyModelConfig, Optional[SubModelType]]:
|
||||
if submodel_type == SubModelType.Vae and hasattr(config, "vae") and config.vae is not None:
|
||||
model_path = Path(config.vae)
|
||||
config_class = (
|
||||
VaeCheckpointConfig if model_path.suffix in [".pt", ".safetensors", ".ckpt"] else VaeDiffusersConfig
|
||||
)
|
||||
hash = hashlib.md5(model_path.as_posix().encode("utf-8")).hexdigest()
|
||||
new_conf = config_class(path=model_path.as_posix(), name=model_path.stem, base=config.base, key=hash)
|
||||
submodel_type = None
|
||||
else:
|
||||
new_conf = config
|
||||
return new_conf, submodel_type
|
||||
|
||||
@classmethod
|
||||
def register(
|
||||
@ -152,4 +173,3 @@ class AnyModelLoader:
|
||||
return subclass
|
||||
|
||||
return decorator
|
||||
|
||||
|
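
The get_implementation() change in the preceding hunks looks a loader up first under its exact (base, type, format) key and then under a wildcard BaseModelType.Any key. A small self-contained sketch of that two-key fallback, with plain strings standing in for the real enums:

from typing import Dict, Type

ANY_BASE = "any"  # illustrative stand-in for BaseModelType.Any


def to_registry_key(base: str, type_: str, format_: str) -> str:
    return "-".join([base, type_, format_])


def get_implementation(registry: Dict[str, Type], base: str, type_: str, format_: str) -> Type:
    """Look up an exact (base, type, format) match first, then fall back to a wildcard base."""
    key1 = to_registry_key(base, type_, format_)      # for a specific base, e.g. "sd-1"
    key2 = to_registry_key(ANY_BASE, type_, format_)  # with wildcard Any
    implementation = registry.get(key1) or registry.get(key2)
    if implementation is None:
        raise NotImplementedError(f"No loader registered for base={base}, type={type_}, format={format_}")
    return implementation


class _VaeLoaderExample:
    pass


registry = {to_registry_key(ANY_BASE, "vae", "diffusers"): _VaeLoaderExample}
assert get_implementation(registry, "sd-1", "vae", "diffusers") is _VaeLoaderExample
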
@ -10,12 +10,12 @@ from diffusers import ModelMixin
|
||||
from diffusers.configuration_utils import ConfigMixin
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_manager import AnyModelConfig, InvalidModelConfigException, ModelRepoVariant, SubModelType
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, InvalidModelConfigException, ModelRepoVariant, SubModelType
|
||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||
from invokeai.backend.model_manager.load.load_base import AnyModel, LoadedModel, ModelLoaderBase
|
||||
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
|
||||
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
|
||||
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs, calc_model_size_by_data
|
||||
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
|
||||
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
|
||||
|
||||
|
||||
@ -38,7 +38,7 @@ class ModelLoader(ModelLoaderBase):
|
||||
self,
|
||||
app_config: InvokeAIAppConfig,
|
||||
logger: Logger,
|
||||
ram_cache: ModelCacheBase,
|
||||
ram_cache: ModelCacheBase[AnyModel],
|
||||
convert_cache: ModelConvertCacheBase,
|
||||
):
|
||||
"""Initialize the loader."""
|
||||
@ -47,7 +47,6 @@ class ModelLoader(ModelLoaderBase):
|
||||
self._ram_cache = ram_cache
|
||||
self._convert_cache = convert_cache
|
||||
self._torch_dtype = torch_dtype(choose_torch_device())
|
||||
self._size: Optional[int] = None # model size
|
||||
|
||||
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||
"""
|
||||
@ -63,9 +62,7 @@ class ModelLoader(ModelLoaderBase):
|
||||
if model_config.type == "main" and not submodel_type:
|
||||
raise InvalidModelConfigException("submodel_type is required when loading a main model")
|
||||
|
||||
model_path, is_submodel_override = self._get_model_path(model_config, submodel_type)
|
||||
if is_submodel_override:
|
||||
submodel_type = None
|
||||
model_path, model_config, submodel_type = self._get_model_path(model_config, submodel_type)
|
||||
|
||||
if not model_path.exists():
|
||||
raise InvalidModelConfigException(f"Files for model 'model_config.name' not found at {model_path}")
|
||||
@ -74,13 +71,12 @@ class ModelLoader(ModelLoaderBase):
|
||||
locker = self._load_if_needed(model_config, model_path, submodel_type)
|
||||
return LoadedModel(config=model_config, locker=locker)
|
||||
|
||||
# IMPORTANT: This needs to be overridden in the StableDiffusion subclass so as to handle vae overrides
|
||||
# and submodels!!!!
|
||||
def _get_model_path(
|
||||
self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None
|
||||
) -> Tuple[Path, bool]:
|
||||
) -> Tuple[Path, AnyModelConfig, Optional[SubModelType]]:
|
||||
model_base = self._app_config.models_path
|
||||
return ((model_base / config.path).resolve(), False)
|
||||
result = (model_base / config.path).resolve(), config, submodel_type
|
||||
return result
|
||||
|
||||
def _convert_if_needed(
|
||||
self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
|
||||
@ -90,7 +86,7 @@ class ModelLoader(ModelLoaderBase):
|
||||
if not self._needs_conversion(config, model_path, cache_path):
|
||||
return cache_path if cache_path.exists() else model_path
|
||||
|
||||
self._convert_cache.make_room(self._size or self.get_size_fs(config, model_path, submodel_type))
|
||||
self._convert_cache.make_room(self.get_size_fs(config, model_path, submodel_type))
|
||||
return self._convert_model(config, model_path, cache_path)
|
||||
|
||||
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, cache_path: Path) -> bool:
|
||||
@ -114,6 +110,7 @@ class ModelLoader(ModelLoaderBase):
|
||||
config.key,
|
||||
submodel_type=submodel_type,
|
||||
model=loaded_model,
|
||||
size=calc_model_size_by_data(loaded_model),
|
||||
)
|
||||
|
||||
return self._ram_cache.get(config.key, submodel_type)
|
||||
@ -128,17 +125,6 @@ class ModelLoader(ModelLoaderBase):
|
||||
variant=config.repo_variant if hasattr(config, "repo_variant") else None,
|
||||
)
|
||||
|
||||
def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
|
||||
raise NotImplementedError
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
model_path: Path,
|
||||
model_variant: Optional[ModelRepoVariant] = None,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
raise NotImplementedError
|
||||
|
||||
def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> Dict[str, Any]:
|
||||
return ConfigLoader.load_config(model_path, config_name=config_name)
|
||||
|
||||
@ -161,3 +147,17 @@ class ModelLoader(ModelLoaderBase):
|
||||
config = self._load_diffusers_config(model_path, config_name="config.json")
|
||||
class_name = config["_class_name"]
|
||||
return self._hf_definition_to_type(module="diffusers", class_name=class_name)
|
||||
|
||||
# This needs to be implemented in subclasses that handle checkpoints
|
||||
def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
|
||||
raise NotImplementedError
|
||||
|
||||
# This needs to be implemented in the subclass
|
||||
def _load_model(
|
||||
self,
|
||||
model_path: Path,
|
||||
model_variant: Optional[ModelRepoVariant] = None,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
raise NotImplementedError
|
||||
|
||||
|
@@ -97,4 +97,4 @@ def get_pretty_snapshot_diff(snapshot_1: Optional[MemorySnapshot], snapshot_2: O
    if snapshot_1.vram is not None and snapshot_2.vram is not None:
        msg += get_msg_line("VRAM", snapshot_1.vram, snapshot_2.vram)

    return msg
    return "\n"+msg if len(msg)>0 else msg
@@ -2,4 +2,4 @@

from .model_cache_base import ModelCacheBase
from .model_cache_default import ModelCache
_all__ = ['ModelCacheBase', 'ModelCache']
_all__ = ["ModelCacheBase", "ModelCache"]
@@ -8,14 +8,15 @@ model will be cleared and (re)loaded from disk when next needed.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from dataclasses import dataclass
from logging import Logger
from typing import Dict, Optional, TypeVar, Generic
from typing import Generic, Optional, TypeVar

import torch

from invokeai.backend.model_manager import AnyModel, SubModelType


class ModelLockerBase(ABC):
    """Base class for the model locker used by the loader."""

@@ -35,8 +36,10 @@ class ModelLockerBase(ABC):
        """Return the model."""
        pass


T = TypeVar("T")


@dataclass
class CacheRecord(Generic[T]):
    """Elements of the cache."""

@@ -115,6 +118,7 @@ class ModelCacheBase(ABC, Generic[T]):
        self,
        key: str,
        model: T,
        size: int,
        submodel_type: Optional[SubModelType] = None,
    ) -> None:
        """Store model under key and optional submodel_type."""
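
The T = TypeVar("T") and CacheRecord(Generic[T]) additions above are what allow signatures like ModelCacheBase[AnyModel] elsewhere in this diff. A minimal illustration of the generic-record pattern with placeholder names:

from dataclasses import dataclass
from typing import Dict, Generic, TypeVar

T = TypeVar("T")


@dataclass
class Record(Generic[T]):
    """Placeholder for CacheRecord: the cached object plus its bookkeeping."""

    key: str
    model: T
    size: int


class Cache(Generic[T]):
    """Placeholder for ModelCacheBase[T]; stores records keyed by string."""

    def __init__(self) -> None:
        self._records: Dict[str, Record[T]] = {}

    def put(self, key: str, model: T, size: int) -> None:
        self._records[key] = Record(key, model, size)


cache: Cache[str] = Cache()
cache.put("demo", "not-a-real-model", size=0)
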
@ -19,22 +19,24 @@ context. Use like this:
|
||||
"""
|
||||
|
||||
import gc
|
||||
import logging
|
||||
import math
|
||||
import sys
|
||||
import time
|
||||
from contextlib import suppress
|
||||
from logging import Logger
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager import SubModelType
|
||||
from invokeai.backend.model_manager.load.load_base import AnyModel
|
||||
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
|
||||
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
from .model_cache_base import CacheRecord, ModelCacheBase
|
||||
from .model_locker import ModelLockerBase, ModelLocker
|
||||
from .model_locker import ModelLocker, ModelLockerBase
|
||||
|
||||
if choose_torch_device() == torch.device("mps"):
|
||||
from torch import mps
|
||||
@ -91,7 +93,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
self._execution_device: torch.device = execution_device
|
||||
self._storage_device: torch.device = storage_device
|
||||
self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
|
||||
self._log_memory_usage = log_memory_usage
|
||||
self._log_memory_usage = log_memory_usage or self._logger.level == logging.DEBUG
|
||||
|
||||
self._cached_models: Dict[str, CacheRecord[AnyModel]] = {}
|
||||
self._cache_stack: List[str] = []
|
||||
@ -141,14 +143,14 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
self,
|
||||
key: str,
|
||||
model: AnyModel,
|
||||
size: int,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> None:
|
||||
"""Store model under key and optional submodel_type."""
|
||||
key = self._make_cache_key(key, submodel_type)
|
||||
assert key not in self._cached_models
|
||||
|
||||
loaded_model_size = calc_model_size_by_data(model)
|
||||
cache_record = CacheRecord(key, model, loaded_model_size)
|
||||
cache_record = CacheRecord(key, model, size)
|
||||
self._cached_models[key] = cache_record
|
||||
self._cache_stack.append(key)
|
||||
|
||||
@ -195,28 +197,32 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
|
||||
if vram_in_use <= reserved:
|
||||
break
|
||||
if not cache_entry.loaded:
|
||||
continue
|
||||
if not cache_entry.locked:
|
||||
self.move_model_to_device(cache_entry, self.storage_device)
|
||||
cache_entry.loaded = False
|
||||
vram_in_use = torch.cuda.memory_allocated() + size_required
|
||||
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM now available for models; max allowed={(reserved/GIG):.2f}GB")
|
||||
self.logger.debug(
|
||||
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
|
||||
)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
if choose_torch_device() == torch.device("mps"):
|
||||
mps.empty_cache()
|
||||
|
||||
# TO DO: Only reason to pass the CacheRecord rather than the model is to get the key and size
|
||||
# for printing debugging messages. Revisit whether this is necessary
|
||||
def move_model_to_device(self, cache_entry: CacheRecord, target_device: torch.device) -> None:
|
||||
def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
|
||||
"""Move model into the indicated device."""
|
||||
# These attributes are not in the base class but in derived classes
|
||||
assert hasattr(cache_entry.model, "device")
|
||||
assert hasattr(cache_entry.model, "to")
|
||||
# These attributes are not in the base ModelMixin class but in derived classes.
|
||||
# Some models don't have these attributes, in which case they run in RAM/CPU.
|
||||
self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
|
||||
if not (hasattr(cache_entry.model, "device") and hasattr(cache_entry.model, "to")):
|
||||
return
|
||||
|
||||
source_device = cache_entry.model.device
|
||||
|
||||
# Note: We compare device types only so that 'cuda' == 'cuda:0'. This would need to be revised to support
|
||||
# multi-GPU.
|
||||
# Note: We compare device types only so that 'cuda' == 'cuda:0'.
|
||||
# This would need to be revised to support multi-GPU.
|
||||
if torch.device(source_device).type == torch.device(target_device).type:
|
||||
return
|
||||
|
||||
@ -227,8 +233,8 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
end_model_to_time = time.time()
|
||||
self.logger.debug(
|
||||
f"Moved model '{cache_entry.key}' from {source_device} to"
|
||||
f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s.\n"
|
||||
f"Estimated model size: {(cache_entry.size/GIG):.3f} GB.\n"
|
||||
f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
|
||||
f"Estimated model size: {(cache_entry.size/GIG):.3f} GB."
|
||||
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
|
||||
)
|
||||
|
||||
@ -291,7 +297,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
f" {(bytes_needed/GIG):.2f} GB"
|
||||
)
|
||||
|
||||
self.logger.debug(f"Before unloading: cached_models={len(self._cached_models)}")
|
||||
self.logger.debug(f"Before making_room: cached_models={len(self._cached_models)}")
|
||||
|
||||
pos = 0
|
||||
models_cleared = 0
|
||||
@ -336,7 +342,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
# 1 from onnx runtime object
|
||||
if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2):
|
||||
self.logger.debug(
|
||||
f"Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
|
||||
f"Removing {model_key} from RAM cache to free at least {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
|
||||
)
|
||||
current_size -= cache_entry.size
|
||||
models_cleared += 1
|
||||
@ -365,4 +371,4 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
if choose_torch_device() == torch.device("mps"):
|
||||
mps.empty_cache()
|
||||
|
||||
self.logger.debug(f"After unloading: cached_models={len(self._cached_models)}")
|
||||
self.logger.debug(f"After making room: cached_models={len(self._cached_models)}")
|
||||
|
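
The offload_unlocked_models() hunk above walks cached models smallest-first and moves unlocked ones back to the storage device until enough VRAM is free. A rough, simplified sketch of that bookkeeping (an approximation for illustration, not the InvokeAI implementation):

from dataclasses import dataclass
from typing import Dict, List

GIG = 2**30


@dataclass
class _Entry:
    """Illustrative stand-in for CacheRecord; not the InvokeAI class."""

    key: str
    size: int
    loaded: bool = True
    locked: bool = False


def offload_until_fits(entries: Dict[str, _Entry], vram_in_use: int, size_required: int, reserved: int) -> List[str]:
    """Pick entries to move off the GPU, smallest first, until the reservation is satisfied."""
    offloaded: List[str] = []
    for entry in sorted(entries.values(), key=lambda e: e.size):
        if vram_in_use + size_required <= reserved:
            break
        if not entry.loaded or entry.locked:
            continue  # locked models are in active use and must stay on the GPU
        entry.loaded = False
        vram_in_use -= entry.size
        offloaded.append(entry.key)
    return offloaded


entries = {k: _Entry(k, s * GIG) for k, s in [("vae", 1), ("unet", 4), ("clip", 2)]}
print(offload_until_fits(entries, vram_in_use=7 * GIG, size_required=2 * GIG, reserved=6 * GIG))
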
@@ -2,9 +2,10 @@
Base class and implementation of a class that moves models in and out of VRAM.
"""

from abc import ABC, abstractmethod

from invokeai.backend.model_manager import AnyModel

from .model_cache_base import ModelLockerBase, ModelCacheBase, CacheRecord
from .model_cache_base import CacheRecord, ModelCacheBase, ModelLockerBase


class ModelLocker(ModelLockerBase):
    """Internal class that mediates movement in and out of GPU."""

@@ -56,4 +57,3 @@ class ModelLocker(ModelLockerBase):
        if not self._cache.lazy_offloading:
            self._cache.offload_unlocked_models(self._cache_entry.size)
        self._cache.print_cuda_stats()
@ -0,0 +1,60 @@
|
||||
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""Class for ControlNet model loading in InvokeAI."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_controlnet_to_diffusers
|
||||
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
|
||||
from .generic_diffusers import GenericDiffusersLoader
|
||||
|
||||
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Diffusers)
|
||||
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Checkpoint)
|
||||
class ControlnetLoader(GenericDiffusersLoader):
|
||||
"""Class to load ControlNet models."""
|
||||
|
||||
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
|
||||
if config.format != ModelFormat.Checkpoint:
|
||||
return False
|
||||
elif (
|
||||
dest_path.exists()
|
||||
and (dest_path / "config.json").stat().st_mtime >= (config.last_modified or 0.0)
|
||||
and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime
|
||||
):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
|
||||
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
|
||||
raise Exception(f"Vae conversion not supported for model type: {config.base}")
|
||||
else:
|
||||
assert hasattr(config, 'config')
|
||||
config_file = config.config
|
||||
|
||||
if weights_path.suffix == ".safetensors":
|
||||
checkpoint = safetensors.torch.load_file(weights_path, device="cpu")
|
||||
else:
|
||||
checkpoint = torch.load(weights_path, map_location="cpu")
|
||||
|
||||
# sometimes weights are hidden under "state_dict", and sometimes not
|
||||
if "state_dict" in checkpoint:
|
||||
checkpoint = checkpoint["state_dict"]
|
||||
|
||||
convert_controlnet_to_diffusers(
|
||||
weights_path,
|
||||
output_path,
|
||||
original_config_file=self._app_config.root_path / config_file,
|
||||
image_size=512,
|
||||
scan_needed=True,
|
||||
from_safetensors=weights_path.suffix == ".safetensors",
|
||||
)
|
||||
return output_path
|
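
This _needs_conversion() check, repeated in the main and VAE loaders later in the diff, treats a converted copy as fresh only when its marker file is newer than both the stored model record and the original checkpoint. A standalone sketch of the same check, with hypothetical argument names:

from pathlib import Path
from typing import Optional


def converted_copy_is_fresh(dest_dir: Path, marker_name: str, source_ckpt: Path, record_mtime: Optional[float]) -> bool:
    """Return True if an existing converted copy can be reused instead of reconverting."""
    marker = dest_dir / marker_name  # e.g. "config.json" or "model_index.json"
    if not dest_dir.exists() or not marker.exists():
        return False
    marker_mtime = marker.stat().st_mtime
    # The copy is stale if either the model record or the original checkpoint
    # changed after the conversion was written out.
    return marker_mtime >= (record_mtime or 0.0) and marker_mtime >= source_ckpt.stat().st_mtime
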
@ -0,0 +1,34 @@
|
||||
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""Class for simple diffusers model loading in InvokeAI."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelRepoVariant,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
|
||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
|
||||
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers)
|
||||
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.T2IAdapter, format=ModelFormat.Diffusers)
|
||||
class GenericDiffusersLoader(ModelLoader):
|
||||
"""Class to load simple diffusers models."""
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
model_path: Path,
|
||||
model_variant: Optional[ModelRepoVariant] = None,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
model_class = self._get_hf_load_class(model_path)
|
||||
if submodel_type is not None:
|
||||
raise Exception(f"There are no submodels in models of type {model_class}")
|
||||
variant = model_variant.value if model_variant else None
|
||||
result: AnyModel = model_class.from_pretrained(model_path, torch_dtype=self._torch_dtype, variant=variant) # type: ignore
|
||||
return result
|
@ -0,0 +1,39 @@
|
||||
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""Class for IP Adapter model loading in InvokeAI."""
|
||||
|
||||
import torch
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.backend.ip_adapter.ip_adapter import build_ip_adapter
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelRepoVariant,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
|
||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
|
||||
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.InvokeAI)
|
||||
class IPAdapterInvokeAILoader(ModelLoader):
|
||||
"""Class to load IP Adapter diffusers models."""
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
model_path: Path,
|
||||
model_variant: Optional[ModelRepoVariant] = None,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if submodel_type is not None:
|
||||
raise ValueError("There are no submodels in an IP-Adapter model.")
|
||||
model = build_ip_adapter(
|
||||
ip_adapter_ckpt_path=model_path / "ip_adapter.bin",
|
||||
device=torch.device("cpu"),
|
||||
dtype=self._torch_dtype,
|
||||
)
|
||||
return model
|
||||
|
invokeai/backend/model_manager/load/model_loaders/lora.py (new file, 76 lines)
@@ -0,0 +1,76 @@
|
||||
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""Class for LoRA model loading in InvokeAI."""
|
||||
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
from logging import Logger
|
||||
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelRepoVariant,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
|
||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
|
||||
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Diffusers)
|
||||
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Lycoris)
|
||||
class LoraLoader(ModelLoader):
|
||||
"""Class to load LoRA models."""
|
||||
|
||||
# We cheat a little bit to get access to the model base
|
||||
def __init__(
|
||||
self,
|
||||
app_config: InvokeAIAppConfig,
|
||||
logger: Logger,
|
||||
ram_cache: ModelCacheBase[AnyModel],
|
||||
convert_cache: ModelConvertCacheBase,
|
||||
):
|
||||
"""Initialize the loader."""
|
||||
super().__init__(app_config, logger, ram_cache, convert_cache)
|
||||
self._model_base: Optional[BaseModelType] = None
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
model_path: Path,
|
||||
model_variant: Optional[ModelRepoVariant] = None,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if submodel_type is not None:
|
||||
raise ValueError("There are no submodels in a LoRA model.")
|
||||
model = LoRAModelRaw.from_checkpoint(
|
||||
file_path=model_path,
|
||||
dtype=self._torch_dtype,
|
||||
base_model=self._model_base,
|
||||
)
|
||||
return model
|
||||
|
||||
# override
|
||||
def _get_model_path(
|
||||
self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None
|
||||
) -> Tuple[Path, AnyModelConfig, Optional[SubModelType]]:
|
||||
self._model_base = config.base # cheating a little - setting this variable for later call to _load_model()
|
||||
|
||||
model_base_path = self._app_config.models_path
|
||||
model_path = model_base_path / config.path
|
||||
|
||||
if config.format == ModelFormat.Diffusers:
|
||||
for ext in ["safetensors", "bin"]: # return path to the safetensors file inside the folder
|
||||
path = model_base_path / config.path / f"pytorch_lora_weights.{ext}"
|
||||
if path.exists():
|
||||
model_path = path
|
||||
break
|
||||
|
||||
result = model_path.resolve(), config, submodel_type
|
||||
return result
|
||||
|
||||
|
@ -0,0 +1,93 @@
|
||||
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""Class for StableDiffusion model loading in InvokeAI."""
|
||||
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from diffusers import StableDiffusionInpaintPipeline, StableDiffusionPipeline
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelRepoVariant,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import MainCheckpointConfig
|
||||
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
|
||||
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
|
||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
|
||||
|
||||
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Diffusers)
|
||||
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Checkpoint)
|
||||
class StableDiffusionDiffusersModel(ModelLoader):
|
||||
"""Class to load main models."""
|
||||
|
||||
model_base_to_model_type = {
|
||||
BaseModelType.StableDiffusion1: "FrozenCLIPEmbedder",
|
||||
BaseModelType.StableDiffusion2: "FrozenOpenCLIPEmbedder",
|
||||
BaseModelType.StableDiffusionXL: "SDXL",
|
||||
BaseModelType.StableDiffusionXLRefiner: "SDXL-Refiner",
|
||||
}
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
model_path: Path,
|
||||
model_variant: Optional[ModelRepoVariant] = None,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if not submodel_type is not None:
|
||||
raise Exception("A submodel type must be provided when loading main pipelines.")
|
||||
load_class = self._get_hf_load_class(model_path, submodel_type)
|
||||
variant = model_variant.value if model_variant else None
|
||||
model_path = model_path / submodel_type.value
|
||||
result: AnyModel = load_class.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=self._torch_dtype,
|
||||
variant=variant,
|
||||
) # type: ignore
|
||||
return result
|
||||
|
||||
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
|
||||
if config.format != ModelFormat.Checkpoint:
|
||||
return False
|
||||
elif (
|
||||
dest_path.exists()
|
||||
and (dest_path / "model_index.json").stat().st_mtime >= (config.last_modified or 0.0)
|
||||
and (dest_path / "model_index.json").stat().st_mtime >= model_path.stat().st_mtime
|
||||
):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
|
||||
assert isinstance(config, MainCheckpointConfig)
|
||||
variant = config.variant
|
||||
base = config.base
|
||||
pipeline_class = (
|
||||
StableDiffusionInpaintPipeline if variant == ModelVariantType.Inpaint else StableDiffusionPipeline
|
||||
)
|
||||
|
||||
config_file = config.config
|
||||
|
||||
self._logger.info(f"Converting {weights_path} to diffusers format")
|
||||
convert_ckpt_to_diffusers(
|
||||
weights_path,
|
||||
output_path,
|
||||
model_type=self.model_base_to_model_type[base],
|
||||
model_version=base,
|
||||
model_variant=variant,
|
||||
original_config_file=self._app_config.root_path / config_file,
|
||||
extract_ema=True,
|
||||
scan_needed=True,
|
||||
pipeline_class=pipeline_class,
|
||||
from_safetensors=weights_path.suffix == ".safetensors",
|
||||
precision=self._torch_dtype,
|
||||
load_safety_checker=False,
|
||||
)
|
||||
return output_path
|
@ -2,68 +2,54 @@
|
||||
"""Class for VAE model loading in InvokeAI."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import safetensors
|
||||
from omegaconf import OmegaConf, DictConfig
|
||||
from invokeai.backend.util.devices import torch_dtype
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelFormat, ModelRepoVariant, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
|
||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
import torch
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
|
||||
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
|
||||
from .generic_diffusers import GenericDiffusersLoader
|
||||
|
||||
|
||||
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers)
|
||||
@AnyModelLoader.register(base=BaseModelType.StableDiffusion1, type=ModelType.Vae, format=ModelFormat.Checkpoint)
|
||||
@AnyModelLoader.register(base=BaseModelType.StableDiffusion2, type=ModelType.Vae, format=ModelFormat.Checkpoint)
|
||||
class VaeDiffusersModel(ModelLoader):
|
||||
class VaeLoader(GenericDiffusersLoader):
|
||||
"""Class to load VAE models."""
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
model_path: Path,
|
||||
model_variant: Optional[ModelRepoVariant] = None,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if submodel_type is not None:
|
||||
raise Exception("There are no submodels in VAEs")
|
||||
vae_class = self._get_hf_load_class(model_path)
|
||||
variant = model_variant.value if model_variant else None
|
||||
result: AnyModel = vae_class.from_pretrained(
|
||||
model_path, torch_dtype=self._torch_dtype, variant=variant
|
||||
) # type: ignore
|
||||
return result
|
||||
|
||||
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
|
||||
print(f'DEBUG: last_modified={config.last_modified}')
|
||||
print(f'DEBUG: cache_path={(dest_path / "config.json").stat().st_mtime}')
|
||||
print(f'DEBUG: model_path={model_path.stat().st_mtime}')
|
||||
if config.format != ModelFormat.Checkpoint:
|
||||
return False
|
||||
elif dest_path.exists() \
|
||||
and (dest_path / "config.json").stat().st_mtime >= config.last_modified \
|
||||
and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime:
|
||||
elif (
|
||||
dest_path.exists()
|
||||
and (dest_path / "config.json").stat().st_mtime >= (config.last_modified or 0.0)
|
||||
and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime
|
||||
):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def _convert_model(self,
|
||||
config: AnyModelConfig,
|
||||
weights_path: Path,
|
||||
output_path: Path
|
||||
) -> Path:
|
||||
def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
|
||||
# TO DO: check whether sdxl VAE models convert.
|
||||
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
|
||||
raise Exception(f"Vae conversion not supported for model type: {config.base}")
|
||||
else:
|
||||
config_file = 'v1-inference.yaml' if config.base == BaseModelType.StableDiffusion1 else "v2-inference-v.yaml"
|
||||
config_file = (
|
||||
"v1-inference.yaml" if config.base == BaseModelType.StableDiffusion1 else "v2-inference-v.yaml"
|
||||
)
|
||||
|
||||
if weights_path.suffix == ".safetensors":
|
||||
checkpoint = safetensors.torch.load_file(weights_path, device="cpu")
|
||||
else:
|
||||
checkpoint = torch.load(weights_path, map_location="cpu")
|
||||
|
||||
dtype = torch_dtype()
|
||||
|
||||
# sometimes weights are hidden under "state_dict", and sometimes not
|
||||
if "state_dict" in checkpoint:
|
||||
checkpoint = checkpoint["state_dict"]
|
||||
@ -71,13 +57,11 @@ class VaeDiffusersModel(ModelLoader):
|
||||
ckpt_config = OmegaConf.load(self._app_config.legacy_conf_path / config_file)
|
||||
assert isinstance(ckpt_config, DictConfig)
|
||||
|
||||
print(f'DEBUG: CONVERTIGN')
|
||||
vae_model = convert_ldm_vae_to_diffusers(
|
||||
checkpoint=checkpoint,
|
||||
vae_config=ckpt_config,
|
||||
image_size=512,
|
||||
)
|
||||
vae_model.to(dtype) # set precision appropriately
|
||||
vae_model.save_pretrained(output_path, safe_serialization=True, torch_dtype=dtype)
|
||||
vae_model.to(self._torch_dtype) # set precision appropriately
|
||||
vae_model.save_pretrained(output_path, safe_serialization=True)
|
||||
return output_path
|
||||
|
||||
|
@@ -8,10 +8,11 @@ from typing import Optional, Union
import torch
from diffusers import DiffusionPipeline

from invokeai.backend.model_manager.config import AnyModel
from invokeai.backend.model_manager.onnx_runtime import IAIOnnxRuntimeModel


def calc_model_size_by_data(model: Union[DiffusionPipeline, torch.nn.Module, IAIOnnxRuntimeModel]) -> int:
def calc_model_size_by_data(model: AnyModel) -> int:
    """Get size of a model in memory in bytes."""
    if isinstance(model, DiffusionPipeline):
        return _calc_pipeline_by_data(model)

@@ -50,7 +51,7 @@ def calc_model_size_by_fs(model_path: Path, subfolder: Optional[str] = None, var
    """Estimate the size of a model on disk in bytes."""
    if model_path.is_file():
        return model_path.stat().st_size

    if subfolder is not None:
        model_path = model_path / subfolder
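
calc_model_size_by_data() now accepts any AnyModel; for a plain torch.nn.Module the usual estimate is the sum of parameter and buffer storage, roughly as sketched below (not the exact InvokeAI helper):

import torch


def module_size_in_bytes(model: torch.nn.Module) -> int:
    """Approximate RAM footprint of a torch module: parameters plus buffers."""
    mem = 0
    for param in model.parameters():
        mem += param.nelement() * param.element_size()
    for buf in model.buffers():
        mem += buf.nelement() * buf.element_size()
    return mem


print(module_size_in_bytes(torch.nn.Linear(16, 16)))  # 256 weights + 16 biases, 4 bytes each -> 1088
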
invokeai/backend/model_manager/lora.py (new file, 620 lines)
@@ -0,0 +1,620 @@
|
||||
# Copyright (c) 2024 The InvokeAI Development team
|
||||
"""LoRA model support."""
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Union, List, Tuple
|
||||
from typing_extensions import Self
|
||||
from invokeai.backend.model_manager import BaseModelType
|
||||
|
||||
class LoRALayerBase:
|
||||
# rank: Optional[int]
|
||||
# alpha: Optional[float]
|
||||
# bias: Optional[torch.Tensor]
|
||||
# layer_key: str
|
||||
|
||||
# @property
|
||||
# def scale(self):
|
||||
# return self.alpha / self.rank if (self.alpha and self.rank) else 1.0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
if "alpha" in values:
|
||||
self.alpha = values["alpha"].item()
|
||||
else:
|
||||
self.alpha = None
|
||||
|
||||
if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
|
||||
self.bias: Optional[torch.Tensor] = torch.sparse_coo_tensor(
|
||||
values["bias_indices"],
|
||||
values["bias_values"],
|
||||
tuple(values["bias_size"]),
|
||||
)
|
||||
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
self.rank = None # set in layer implementation
|
||||
self.layer_key = layer_key
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError()
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = 0
|
||||
for val in [self.bias]:
|
||||
if val is not None:
|
||||
model_size += val.nelement() * val.element_size()
|
||||
return model_size
|
||||
|
||||
def to(
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> None:
|
||||
if self.bias is not None:
|
||||
self.bias = self.bias.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
# TODO: find and debug lora/locon with bias
|
||||
class LoRALayer(LoRALayerBase):
|
||||
# up: torch.Tensor
|
||||
# mid: Optional[torch.Tensor]
|
||||
# down: torch.Tensor
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.up = values["lora_up.weight"]
|
||||
self.down = values["lora_down.weight"]
|
||||
if "lora_mid.weight" in values:
|
||||
self.mid: Optional[torch.Tensor] = values["lora_mid.weight"]
|
||||
else:
|
||||
self.mid = None
|
||||
|
||||
self.rank = self.down.shape[0]
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor):
|
||||
if self.mid is not None:
|
||||
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
|
||||
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
|
||||
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
|
||||
else:
|
||||
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
|
||||
|
||||
return weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
for val in [self.up, self.mid, self.down]:
|
||||
if val is not None:
|
||||
model_size += val.nelement() * val.element_size()
|
||||
return model_size
|
||||
|
||||
def to(
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.up = self.up.to(device=device, dtype=dtype)
|
||||
self.down = self.down.to(device=device, dtype=dtype)
|
||||
|
||||
if self.mid is not None:
|
||||
self.mid = self.mid.to(device=device, dtype=dtype)


class LoHALayer(LoRALayerBase):
    # w1_a: torch.Tensor
    # w1_b: torch.Tensor
    # w2_a: torch.Tensor
    # w2_b: torch.Tensor
    # t1: Optional[torch.Tensor] = None
    # t2: Optional[torch.Tensor] = None

    def __init__(
        self,
        layer_key: str,
        values: Dict[str, torch.Tensor],
    ):
        super().__init__(layer_key, values)

        self.w1_a = values["hada_w1_a"]
        self.w1_b = values["hada_w1_b"]
        self.w2_a = values["hada_w2_a"]
        self.w2_b = values["hada_w2_b"]

        if "hada_t1" in values:
            self.t1: Optional[torch.Tensor] = values["hada_t1"]
        else:
            self.t1 = None

        if "hada_t2" in values:
            self.t2: Optional[torch.Tensor] = values["hada_t2"]
        else:
            self.t2 = None

        self.rank = self.w1_b.shape[0]

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        if self.t1 is None:
            weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
        else:
            rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
            rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a)
            weight = rebuild1 * rebuild2

        return weight

    def calc_size(self) -> int:
        model_size = super().calc_size()
        for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
            if val is not None:
                model_size += val.nelement() * val.element_size()
        return model_size

    def to(
        self,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        super().to(device=device, dtype=dtype)

        self.w1_a = self.w1_a.to(device=device, dtype=dtype)
        self.w1_b = self.w1_b.to(device=device, dtype=dtype)
        if self.t1 is not None:
            self.t1 = self.t1.to(device=device, dtype=dtype)

        self.w2_a = self.w2_a.to(device=device, dtype=dtype)
        self.w2_b = self.w2_b.to(device=device, dtype=dtype)
        if self.t2 is not None:
            self.t2 = self.t2.to(device=device, dtype=dtype)
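
A shape-only sketch of the Hadamard-product (LoHA) reconstruction above; all dimensions
are arbitrary and chosen only for illustration.

    import torch

    rank, out_f, in_f = 4, 64, 32
    values = {
        "hada_w1_a": torch.randn(out_f, rank),
        "hada_w1_b": torch.randn(rank, in_f),
        "hada_w2_a": torch.randn(out_f, rank),
        "hada_w2_b": torch.randn(rank, in_f),
    }
    loha = LoHALayer("demo_key", values)
    delta = loha.get_weight(torch.zeros(out_f, in_f))
    assert delta.shape == (out_f, in_f)  # (w1_a @ w1_b) * (w2_a @ w2_b), elementwise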


class LoKRLayer(LoRALayerBase):
    # w1: Optional[torch.Tensor] = None
    # w1_a: Optional[torch.Tensor] = None
    # w1_b: Optional[torch.Tensor] = None
    # w2: Optional[torch.Tensor] = None
    # w2_a: Optional[torch.Tensor] = None
    # w2_b: Optional[torch.Tensor] = None
    # t2: Optional[torch.Tensor] = None

    def __init__(
        self,
        layer_key: str,
        values: Dict[str, torch.Tensor],
    ):
        super().__init__(layer_key, values)

        if "lokr_w1" in values:
            self.w1: Optional[torch.Tensor] = values["lokr_w1"]
            self.w1_a = None
            self.w1_b = None
        else:
            self.w1 = None
            self.w1_a = values["lokr_w1_a"]
            self.w1_b = values["lokr_w1_b"]

        if "lokr_w2" in values:
            self.w2: Optional[torch.Tensor] = values["lokr_w2"]
            self.w2_a = None
            self.w2_b = None
        else:
            self.w2 = None
            self.w2_a = values["lokr_w2_a"]
            self.w2_b = values["lokr_w2_b"]

        if "lokr_t2" in values:
            self.t2: Optional[torch.Tensor] = values["lokr_t2"]
        else:
            self.t2 = None

        if "lokr_w1_b" in values:
            self.rank = values["lokr_w1_b"].shape[0]
        elif "lokr_w2_b" in values:
            self.rank = values["lokr_w2_b"].shape[0]
        else:
            self.rank = None  # unscaled

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        w1: Optional[torch.Tensor] = self.w1
        if w1 is None:
            assert self.w1_a is not None
            assert self.w1_b is not None
            w1 = self.w1_a @ self.w1_b

        w2 = self.w2
        if w2 is None:
            if self.t2 is None:
                assert self.w2_a is not None
                assert self.w2_b is not None
                w2 = self.w2_a @ self.w2_b
            else:
                w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b)

        if len(w2.shape) == 4:
            w1 = w1.unsqueeze(2).unsqueeze(2)
        w2 = w2.contiguous()
        assert w1 is not None
        assert w2 is not None
        weight = torch.kron(w1, w2)

        return weight

    def calc_size(self) -> int:
        model_size = super().calc_size()
        for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
            if val is not None:
                model_size += val.nelement() * val.element_size()
        return model_size

    def to(
        self,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        super().to(device=device, dtype=dtype)

        if self.w1 is not None:
            self.w1 = self.w1.to(device=device, dtype=dtype)
        else:
            assert self.w1_a is not None
            assert self.w1_b is not None
            self.w1_a = self.w1_a.to(device=device, dtype=dtype)
            self.w1_b = self.w1_b.to(device=device, dtype=dtype)

        if self.w2 is not None:
            self.w2 = self.w2.to(device=device, dtype=dtype)
        else:
            assert self.w2_a is not None
            assert self.w2_b is not None
            self.w2_a = self.w2_a.to(device=device, dtype=dtype)
            self.w2_b = self.w2_b.to(device=device, dtype=dtype)

        if self.t2 is not None:
            self.t2 = self.t2.to(device=device, dtype=dtype)
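
A shape-only sketch of the Kronecker reconstruction in LoKRLayer.get_weight(); the factor
shapes are assumptions. torch.kron expands an (8, 4) factor and an (8, 8) factor into a
(64, 32) delta, which is how small LoKR factors cover a full-size weight.

    import torch

    values = {
        "lokr_w1": torch.randn(8, 4),    # w1 supplied directly
        "lokr_w2_a": torch.randn(8, 2),  # w2 supplied as a low-rank pair
        "lokr_w2_b": torch.randn(2, 8),
    }
    lokr = LoKRLayer("demo_key", values)
    delta = lokr.get_weight(torch.zeros(64, 32))
    assert delta.shape == (64, 32)       # 8*8 rows, 4*8 columns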


class FullLayer(LoRALayerBase):
    # weight: torch.Tensor

    def __init__(
        self,
        layer_key: str,
        values: Dict[str, torch.Tensor],
    ):
        super().__init__(layer_key, values)

        self.weight = values["diff"]

        if len(values.keys()) > 1:
            _keys = list(values.keys())
            _keys.remove("diff")
            raise NotImplementedError(f"Unexpected keys in lora diff layer: {_keys}")

        self.rank = None  # unscaled

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        return self.weight

    def calc_size(self) -> int:
        model_size = super().calc_size()
        model_size += self.weight.nelement() * self.weight.element_size()
        return model_size

    def to(
        self,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        super().to(device=device, dtype=dtype)

        self.weight = self.weight.to(device=device, dtype=dtype)


class IA3Layer(LoRALayerBase):
    # weight: torch.Tensor
    # on_input: torch.Tensor

    def __init__(
        self,
        layer_key: str,
        values: Dict[str, torch.Tensor],
    ):
        super().__init__(layer_key, values)

        self.weight = values["weight"]
        self.on_input = values["on_input"]

        self.rank = None  # unscaled

    def get_weight(self, orig_weight: torch.Tensor):
        weight = self.weight
        if not self.on_input:
            weight = weight.reshape(-1, 1)
        return orig_weight * weight

    def calc_size(self) -> int:
        model_size = super().calc_size()
        model_size += self.weight.nelement() * self.weight.element_size()
        model_size += self.on_input.nelement() * self.on_input.element_size()
        return model_size

    def to(
        self,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        super().to(device=device, dtype=dtype)

        self.weight = self.weight.to(device=device, dtype=dtype)
        self.on_input = self.on_input.to(device=device, dtype=dtype)
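
Unlike the additive layers above, IA3 rescales the original weight: get_weight() multiplies
orig_weight by a learned per-channel vector. A small sketch with invented shapes:

    import torch

    out_f, in_f = 64, 32
    values = {
        "weight": torch.randn(out_f),
        "on_input": torch.tensor(False),
    }
    ia3 = IA3Layer("demo_key", values)
    orig = torch.randn(out_f, in_f)
    rescaled = ia3.get_weight(orig)
    assert rescaled.shape == (out_f, in_f)  # each output row scaled by one learned factor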


AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]


# TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix
class LoRAModelRaw:  # (torch.nn.Module):
    _name: str
    layers: Dict[str, AnyLoRALayer]

    def __init__(
        self,
        name: str,
        layers: Dict[str, AnyLoRALayer],
    ):
        self._name = name
        self.layers = layers

    @property
    def name(self) -> str:
        return self._name

    def to(
        self,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        # TODO: try revert if exception?
        for _key, layer in self.layers.items():
            layer.to(device=device, dtype=dtype)

    def calc_size(self) -> int:
        model_size = 0
        for _, layer in self.layers.items():
            model_size += layer.calc_size()
        return model_size

    @classmethod
    def _convert_sdxl_keys_to_diffusers_format(cls, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Convert the keys of an SDXL LoRA state_dict to diffusers format.

        The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
        diffusers format, then this function will have no effect.

        This function is adapted from:
        https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409

        Args:
            state_dict (Dict[str, Tensor]): The SDXL LoRA state_dict.

        Raises:
            ValueError: If state_dict contains an unrecognized key, or not all keys could be converted.

        Returns:
            Dict[str, Tensor]: The diffusers-format state_dict.
        """
        converted_count = 0  # The number of Stability AI keys converted to diffusers format.
        not_converted_count = 0  # The number of keys that were not converted.

        # Get a sorted list of Stability AI UNet keys so that we can efficiently search for keys with matching prefixes.
        # For example, we want to efficiently find `input_blocks_4_1` in the list when searching for
        # `input_blocks_4_1_proj_in`.
        stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
        stability_unet_keys.sort()

        new_state_dict = {}
        for full_key, value in state_dict.items():
            if full_key.startswith("lora_unet_"):
                search_key = full_key.replace("lora_unet_", "")
                # Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix.
                position = bisect.bisect_right(stability_unet_keys, search_key)
                map_key = stability_unet_keys[position - 1]
                # Now, check if the map_key *actually* matches the search_key.
                if search_key.startswith(map_key):
                    new_key = full_key.replace(map_key, SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP[map_key])
                    new_state_dict[new_key] = value
                    converted_count += 1
                else:
                    new_state_dict[full_key] = value
                    not_converted_count += 1
            elif full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
                # The CLIP text encoders have the same keys in both Stability AI and diffusers formats.
                new_state_dict[full_key] = value
                continue
            else:
                raise ValueError(f"Unrecognized SDXL LoRA key prefix: '{full_key}'.")

        if converted_count > 0 and not_converted_count > 0:
            raise ValueError(
                f"The SDXL LoRA could only be partially converted to diffusers format. converted={converted_count},"
                f" not_converted={not_converted_count}"
            )

        return new_state_dict

    @classmethod
    def from_checkpoint(
        cls,
        file_path: Union[str, Path],
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        base_model: Optional[BaseModelType] = None,
    ) -> Self:
        device = device or torch.device("cpu")
        dtype = dtype or torch.float32

        if isinstance(file_path, str):
            file_path = Path(file_path)

        model = cls(
            name=file_path.stem,  # TODO:
            layers={},
        )

        if file_path.suffix == ".safetensors":
            state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
        else:
            state_dict = torch.load(file_path, map_location="cpu")

        state_dict = cls._group_state(state_dict)

        if base_model == BaseModelType.StableDiffusionXL:
            state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)

        for layer_key, values in state_dict.items():
            # lora and locon
            if "lora_down.weight" in values:
                layer: AnyLoRALayer = LoRALayer(layer_key, values)

            # loha
            elif "hada_w1_b" in values:
                layer = LoHALayer(layer_key, values)

            # lokr
            elif "lokr_w1_b" in values or "lokr_w1" in values:
                layer = LoKRLayer(layer_key, values)

            # diff
            elif "diff" in values:
                layer = FullLayer(layer_key, values)

            # ia3
            elif "weight" in values and "on_input" in values:
                layer = IA3Layer(layer_key, values)

            else:
                print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
                raise Exception("Unknown lora format!")

            # lower memory consumption by removing already parsed layer values
            state_dict[layer_key].clear()

            layer.to(device=device, dtype=dtype)
            model.layers[layer_key] = layer

        return model

    @staticmethod
    def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
        state_dict_groupped: Dict[str, Dict[str, torch.Tensor]] = {}

        for key, value in state_dict.items():
            stem, leaf = key.split(".", 1)
            if stem not in state_dict_groupped:
                state_dict_groupped[stem] = {}
            state_dict_groupped[stem][leaf] = value

        return state_dict_groupped
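
A hedged usage sketch for LoRAModelRaw.from_checkpoint(): the file path is hypothetical, and
the call only succeeds for checkpoints whose layers match one of the five formats handled
above (lora/locon, loha, lokr, diff, ia3).

    import torch
    from invokeai.backend.model_manager import BaseModelType  # adjust if the enum lives elsewhere

    lora = LoRAModelRaw.from_checkpoint(
        "/path/to/some_sdxl_lora.safetensors",       # illustrative path
        device=torch.device("cpu"),
        dtype=torch.float16,
        base_model=BaseModelType.StableDiffusionXL,  # triggers the SDXL key conversion below
    )
    print(lora.name, len(lora.layers), f"{lora.calc_size() / 2**20:.1f} MiB")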


# code from
# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
def make_sdxl_unet_conversion_map() -> List[Tuple[str, str]]:
    """Create a list of (Stability AI, HF Diffusers) key-prefix pairs for mapping an SDXL UNet state_dict."""
    unet_conversion_map_layer = []

    for i in range(3):  # num_blocks is 3 in sdxl
        # loop over downblocks/upblocks
        for j in range(2):
            # loop over resnets/attentions for downblocks
            hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
            sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
            unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))

            if i < 3:
                # no attention layers in down_blocks.3
                hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
                sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
                unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))

        for j in range(3):
            # loop over resnets/attentions for upblocks
            hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
            sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
            unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))

            # if i > 0: commentout for sdxl
            # no attention layers in up_blocks.0
            hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
            sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))

        if i < 3:
            # no downsample in down_blocks.3
            hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
            sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
            unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))

            # no upsample in up_blocks.3
            hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
            sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}."  # change for sdxl
            unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))

    hf_mid_atn_prefix = "mid_block.attentions.0."
    sd_mid_atn_prefix = "middle_block.1."
    unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))

    for j in range(2):
        hf_mid_res_prefix = f"mid_block.resnets.{j}."
        sd_mid_res_prefix = f"middle_block.{2*j}."
        unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))

    unet_conversion_map_resnet = [
        # (stable-diffusion, HF Diffusers)
        ("in_layers.0.", "norm1."),
        ("in_layers.2.", "conv1."),
        ("out_layers.0.", "norm2."),
        ("out_layers.3.", "conv2."),
        ("emb_layers.1.", "time_emb_proj."),
        ("skip_connection.", "conv_shortcut."),
    ]

    unet_conversion_map = []
    for sd, hf in unet_conversion_map_layer:
        if "resnets" in hf:
            for sd_res, hf_res in unet_conversion_map_resnet:
                unet_conversion_map.append((sd + sd_res, hf + hf_res))
        else:
            unet_conversion_map.append((sd, hf))

    for j in range(2):
        hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
        sd_time_embed_prefix = f"time_embed.{j*2}."
        unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))

    for j in range(2):
        hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
        sd_label_embed_prefix = f"label_emb.0.{j*2}."
        unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))

    unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
    unet_conversion_map.append(("out.0.", "conv_norm_out."))
    unet_conversion_map.append(("out.2.", "conv_out."))

    return unet_conversion_map


SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = {
    sd.rstrip(".").replace(".", "_"): hf.rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()
}
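
To make the conversion concrete, here is one prefix from the table and the key rewrite it
produces in _convert_sdxl_keys_to_diffusers_format(); the layer key itself is a made-up example.

    assert SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP["input_blocks_4_1"] == "down_blocks_1_attentions_0"

    # so a grouped Stability AI LoRA key such as
    #     lora_unet_input_blocks_4_1_proj_in
    # is rewritten to the diffusers-style key
    #     lora_unet_down_blocks_1_attentions_0_proj_in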

@ -29,8 +29,12 @@ CkptType = Dict[str, Any]

LEGACY_CONFIGS: Dict[BaseModelType, Dict[ModelVariantType, Union[str, Dict[SchedulerPredictionType, str]]]] = {
    BaseModelType.StableDiffusion1: {
        ModelVariantType.Normal: "v1-inference.yaml",
        ModelVariantType.Normal: {
            SchedulerPredictionType.Epsilon: "v1-inference.yaml",
            SchedulerPredictionType.VPrediction: "v1-inference-v.yaml",
        },
        ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
        ModelVariantType.Depth: "v2-midas-inference.yaml",
    },
    BaseModelType.StableDiffusion2: {
        ModelVariantType.Normal: {

@ -12,14 +12,22 @@ from .devices import ( # noqa: F401
    torch_dtype,
)
from .logging import InvokeAILogger
from .util import ( # TO DO: Clean this up; remove the unused symbols
from .util import ( # TO DO: Clean this up; remove the unused symbols
    GIG,
    Chdir,
    ask_user, # noqa
    directory_size,
    download_with_resume,
    instantiate_from_config, # noqa
    instantiate_from_config, # noqa
    url_attachment_name, # noqa
)
)

__all__ = ["GIG", "directory_size","Chdir", "download_with_resume", "InvokeAILogger", "choose_precision", "choose_torch_device"]
__all__ = [
    "GIG",
    "directory_size",
    "Chdir",
    "download_with_resume",
    "InvokeAILogger",
    "choose_precision",
    "choose_torch_device",
]

@ -1,7 +1,7 @@
from __future__ import annotations

from contextlib import nullcontext
from typing import Union, Optional
from typing import Optional, Union

import torch
from torch import autocast

@ -27,6 +27,7 @@ from .devices import torch_dtype

# actual size of a gig
GIG = 1073741824


def directory_size(directory: Path) -> int:
    """
    Return the aggregate size of all files in a directory (bytes).
@ -39,6 +40,7 @@ def directory_size(directory: Path) -> int:
            sum += Path(root, d).stat().st_size
    return sum
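
A brief usage sketch (the directory path is illustrative):

    from pathlib import Path

    models_bytes = directory_size(Path("/path/to/invokeai/models"))
    print(f"{models_bytes / GIG:.2f} GB")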


def log_txt_as_img(wh, xc, size=10):
    # wh a tuple of (width, height)
    # xc a list of captions to plot