loaders for main, controlnet, ip-adapter, clipvision and t2i

This commit is contained in:
Lincoln Stein 2024-02-04 17:23:10 -05:00
parent 9804cb0e67
commit 420f6050a6
33 changed files with 1131 additions and 168 deletions

View File

@ -4,10 +4,10 @@ from logging import Logger
from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory
from invokeai.app.services.shared.sqlite.sqlite_util import init_db from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.backend.model_manager.metadata import ModelMetadataStore
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.backend.model_manager.load import AnyModelLoader, ModelConvertCache from invokeai.backend.model_manager.load import AnyModelLoader, ModelConvertCache
from invokeai.backend.model_manager.load.model_cache import ModelCache from invokeai.backend.model_manager.load.model_cache import ModelCache
from invokeai.backend.model_manager.metadata import ModelMetadataStore
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.version.invokeai_version import __version__ from invokeai.version.invokeai_version import __version__
from ..services.board_image_records.board_image_records_sqlite import SqliteBoardImageRecordStorage from ..services.board_image_records.board_image_records_sqlite import SqliteBoardImageRecordStorage
@ -90,15 +90,14 @@ class ApiDependencies:
model_loader = AnyModelLoader( model_loader = AnyModelLoader(
app_config=config, app_config=config,
logger=logger, logger=logger,
ram_cache=ModelCache(max_cache_size=config.ram_cache_size, ram_cache=ModelCache(
max_vram_cache_size=config.vram_cache_size, max_cache_size=config.ram_cache_size, max_vram_cache_size=config.vram_cache_size, logger=logger
logger=logger), ),
convert_cache=ModelConvertCache( convert_cache=ModelConvertCache(
cache_path = config.models_convert_cache_path, cache_path=config.models_convert_cache_path, max_size=config.convert_cache_size
max_size = config.convert_cache_size ),
) )
) model_record_service = ModelRecordServiceSQL(db=db, loader=model_loader)
model_record_service = ModelRecordServiceSQL(db=db,loader=model_loader)
download_queue_service = DownloadQueueService(event_bus=events) download_queue_service = DownloadQueueService(event_bus=events)
model_install_service = ModelInstallService( model_install_service = ModelInstallService(
app_config=config, app_config=config,

View File

@ -173,7 +173,7 @@ from __future__ import annotations
import os import os
from pathlib import Path from pathlib import Path
from typing import Any, ClassVar, Dict, List, Literal, Optional, Union, get_type_hints from typing import Any, ClassVar, Dict, List, Literal, Optional, get_type_hints
from omegaconf import DictConfig, OmegaConf from omegaconf import DictConfig, OmegaConf
from pydantic import Field, TypeAdapter from pydantic import Field, TypeAdapter

View File

@ -11,7 +11,14 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from invokeai.app.services.shared.pagination import PaginatedResults from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.backend.model_manager import LoadedModel, AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType from invokeai.backend.model_manager import (
AnyModelConfig,
BaseModelType,
LoadedModel,
ModelFormat,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore

View File

@ -42,7 +42,6 @@ Typical usage:
import json import json
import sqlite3 import sqlite3
import time
from math import ceil from math import ceil
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, Union from typing import Any, Dict, List, Optional, Set, Tuple, Union
@ -56,8 +55,8 @@ from invokeai.backend.model_manager.config import (
ModelType, ModelType,
SubModelType, SubModelType,
) )
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore, UnknownMetadataException
from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore, UnknownMetadataException
from ..shared.sqlite.sqlite_database import SqliteDatabase from ..shared.sqlite.sqlite_database import SqliteDatabase
from .model_records_base import ( from .model_records_base import (
@ -72,7 +71,7 @@ from .model_records_base import (
class ModelRecordServiceSQL(ModelRecordServiceBase): class ModelRecordServiceSQL(ModelRecordServiceBase):
"""Implementation of the ModelConfigStore ABC using a SQL database.""" """Implementation of the ModelConfigStore ABC using a SQL database."""
def __init__(self, db: SqliteDatabase, loader: Optional[AnyModelLoader]=None): def __init__(self, db: SqliteDatabase, loader: Optional[AnyModelLoader] = None):
""" """
Initialize a new object from preexisting sqlite3 connection and threading lock objects. Initialize a new object from preexisting sqlite3 connection and threading lock objects.
@ -289,7 +288,9 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
""", """,
tuple(bindings), tuple(bindings),
) )
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()] results = [
ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()
]
return results return results
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]: def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
@ -303,7 +304,9 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
""", """,
(str(path),), (str(path),),
) )
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()] results = [
ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()
]
return results return results
def search_by_hash(self, hash: str) -> List[AnyModelConfig]: def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
@ -317,7 +320,9 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
""", """,
(hash,), (hash,),
) )
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()] results = [
ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()
]
return results return results
@property @property

View File

@ -1,11 +1,9 @@
import sqlite3 import sqlite3
from logging import Logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration6Callback:
class Migration6Callback:
def __call__(self, cursor: sqlite3.Cursor) -> None: def __call__(self, cursor: sqlite3.Cursor) -> None:
self._recreate_model_triggers(cursor) self._recreate_model_triggers(cursor)
@ -28,6 +26,7 @@ class Migration6Callback:
""" """
) )
def build_migration_6() -> Migration: def build_migration_6() -> Migration:
""" """
Build the migration from database version 5 to 6. Build the migration from database version 5 to 6.

View File

@ -120,6 +120,7 @@ class TqdmEventService(EventServiceBase):
elif payload["event"] == "model_install_cancelled": elif payload["event"] == "model_install_cancelled":
self._logger.warning(f"{source}: installation cancelled") self._logger.warning(f"{source}: installation cancelled")
class InstallHelper(object): class InstallHelper(object):
"""Capture information stored jointly in INITIAL_MODELS.yaml and the installed models db.""" """Capture information stored jointly in INITIAL_MODELS.yaml and the installed models db."""

View File

@ -139,7 +139,6 @@ def _convert_controlnet_ckpt_and_cache(
cache it to disk, and return Path to converted cache it to disk, and return Path to converted
file. If already on disk then just returns Path. file. If already on disk then just returns Path.
""" """
print(f"DEBUG: controlnet config = {model_config}")
app_config = InvokeAIAppConfig.get_config() app_config = InvokeAIAppConfig.get_config()
weights = app_config.root_path / model_path weights = app_config.root_path / model_path
output_path = Path(output_path) output_path = Path(output_path)

View File

@ -13,9 +13,9 @@ from .config import (
SchedulerPredictionType, SchedulerPredictionType,
SubModelType, SubModelType,
) )
from .load import LoadedModel
from .probe import ModelProbe from .probe import ModelProbe
from .search import ModelSearch from .search import ModelSearch
from .load import LoadedModel
__all__ = [ __all__ = [
"AnyModel", "AnyModel",

View File

@ -20,14 +20,16 @@ Validation errors will raise an InvalidModelConfigException error.
""" """
import time import time
import torch
from enum import Enum from enum import Enum
from typing import Literal, Optional, Type, Union from typing import Literal, Optional, Type, Union
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter import torch
from diffusers import ModelMixin from diffusers import ModelMixin
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
from typing_extensions import Annotated, Any, Dict from typing_extensions import Annotated, Any, Dict
from .onnx_runtime import IAIOnnxRuntimeModel from .onnx_runtime import IAIOnnxRuntimeModel
from ..ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
class InvalidModelConfigException(Exception): class InvalidModelConfigException(Exception):
"""Exception for when config parser doesn't recognized this combination of model type and format.""" """Exception for when config parser doesn't recognized this combination of model type and format."""
@ -204,6 +206,8 @@ class _MainConfig(ModelConfigBase):
vae: Optional[str] = Field(default=None) vae: Optional[str] = Field(default=None)
variant: ModelVariantType = ModelVariantType.Normal variant: ModelVariantType = ModelVariantType.Normal
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
ztsnr_training: bool = False ztsnr_training: bool = False
@ -217,8 +221,6 @@ class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
"""Model config for main diffusers models.""" """Model config for main diffusers models."""
type: Literal[ModelType.Main] = ModelType.Main type: Literal[ModelType.Main] = ModelType.Main
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
class ONNXSD1Config(_MainConfig): class ONNXSD1Config(_MainConfig):
@ -276,6 +278,7 @@ AnyModelConfig = Union[
_ONNXConfig, _ONNXConfig,
_VaeConfig, _VaeConfig,
_ControlNetConfig, _ControlNetConfig,
# ModelConfigBase,
LoRAConfig, LoRAConfig,
TextualInversionConfig, TextualInversionConfig,
IPAdapterConfig, IPAdapterConfig,
@ -284,7 +287,7 @@ AnyModelConfig = Union[
] ]
AnyModelConfigValidator = TypeAdapter(AnyModelConfig) AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
AnyModel = Union[ModelMixin, torch.nn.Module, IAIOnnxRuntimeModel] AnyModel = Union[ModelMixin, torch.nn.Module, IAIOnnxRuntimeModel, IPAdapter, IPAdapterPlus]
# IMPLEMENTATION NOTE: # IMPLEMENTATION NOTE:
# The preferred alternative to the above is a discriminated Union as shown # The preferred alternative to the above is a discriminated Union as shown
@ -317,7 +320,7 @@ class ModelConfigFactory(object):
model_data: Union[dict, AnyModelConfig], model_data: Union[dict, AnyModelConfig],
key: Optional[str] = None, key: Optional[str] = None,
dest_class: Optional[Type] = None, dest_class: Optional[Type] = None,
timestamp: Optional[float] = None timestamp: Optional[float] = None,
) -> AnyModelConfig: ) -> AnyModelConfig:
""" """
Return the appropriate config object from raw dict values. Return the appropriate config object from raw dict values.

View File

@ -43,7 +43,6 @@ from diffusers.schedulers import (
UnCLIPScheduler, UnCLIPScheduler,
) )
from diffusers.utils import is_accelerate_available from diffusers.utils import is_accelerate_available
from diffusers.utils.import_utils import BACKENDS_MAPPING
from picklescan.scanner import scan_file_path from picklescan.scanner import scan_file_path
from transformers import ( from transformers import (
AutoFeatureExtractor, AutoFeatureExtractor,
@ -58,8 +57,8 @@ from transformers import (
) )
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.backend.model_manager import BaseModelType, ModelVariantType from invokeai.backend.model_manager import BaseModelType, ModelVariantType
from invokeai.backend.util.logging import InvokeAILogger
try: try:
from omegaconf import OmegaConf from omegaconf import OmegaConf
@ -1643,7 +1642,6 @@ def download_controlnet_from_original_ckpt(
cross_attention_dim: Optional[bool] = None, cross_attention_dim: Optional[bool] = None,
scan_needed: bool = False, scan_needed: bool = False,
) -> DiffusionPipeline: ) -> DiffusionPipeline:
from omegaconf import OmegaConf from omegaconf import OmegaConf
if from_safetensors: if from_safetensors:

View File

@ -8,14 +8,15 @@ from typing import Optional
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
from .convert_cache.convert_cache_default import ModelConvertCache
from .load_base import AnyModelLoader, LoadedModel from .load_base import AnyModelLoader, LoadedModel
from .model_cache.model_cache_default import ModelCache from .model_cache.model_cache_default import ModelCache
from .convert_cache.convert_cache_default import ModelConvertCache
# This registers the subclasses that implement loaders of specific model types # This registers the subclasses that implement loaders of specific model types
loaders = [x.stem for x in Path(Path(__file__).parent,'model_loaders').glob('*.py') if x.stem != '__init__'] loaders = [x.stem for x in Path(Path(__file__).parent, "model_loaders").glob("*.py") if x.stem != "__init__"]
for module in loaders: for module in loaders:
print(f'module={module}') print(f"module={module}")
import_module(f"{__package__}.model_loaders.{module}") import_module(f"{__package__}.model_loaders.{module}")
__all__ = ["AnyModelLoader", "LoadedModel"] __all__ = ["AnyModelLoader", "LoadedModel"]
@ -24,12 +25,11 @@ __all__ = ["AnyModelLoader", "LoadedModel"]
def get_standalone_loader(app_config: Optional[InvokeAIAppConfig]) -> AnyModelLoader: def get_standalone_loader(app_config: Optional[InvokeAIAppConfig]) -> AnyModelLoader:
app_config = app_config or InvokeAIAppConfig.get_config() app_config = app_config or InvokeAIAppConfig.get_config()
logger = InvokeAILogger.get_logger(config=app_config) logger = InvokeAILogger.get_logger(config=app_config)
return AnyModelLoader(app_config=app_config, return AnyModelLoader(
app_config=app_config,
logger=logger, logger=logger,
ram_cache=ModelCache(logger=logger, ram_cache=ModelCache(
max_cache_size=app_config.ram_cache_size, logger=logger, max_cache_size=app_config.ram_cache_size, max_vram_cache_size=app_config.vram_cache_size
max_vram_cache_size=app_config.vram_cache_size
), ),
convert_cache=ModelConvertCache(app_config.models_convert_cache_path) convert_cache=ModelConvertCache(app_config.models_convert_cache_path),
) )

View File

@ -1,4 +1,4 @@
from .convert_cache_base import ModelConvertCacheBase from .convert_cache_base import ModelConvertCacheBase
from .convert_cache_default import ModelConvertCache from .convert_cache_default import ModelConvertCache
__all__ = ['ModelConvertCacheBase', 'ModelConvertCache'] __all__ = ["ModelConvertCacheBase", "ModelConvertCache"]

View File

@ -4,8 +4,8 @@ Disk-based converted model cache.
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
class ModelConvertCacheBase(ABC):
class ModelConvertCacheBase(ABC):
@property @property
@abstractmethod @abstractmethod
def max_size(self) -> float: def max_size(self) -> float:
@ -25,4 +25,3 @@ class ModelConvertCacheBase(ABC):
def cache_path(self, key: str) -> Path: def cache_path(self, key: str) -> Path:
"""Return the path for a model with the indicated key.""" """Return the path for a model with the indicated key."""
pass pass

View File

@ -2,15 +2,17 @@
Placeholder for convert cache implementation. Placeholder for convert cache implementation.
""" """
from pathlib import Path
import shutil import shutil
from invokeai.backend.util.logging import InvokeAILogger from pathlib import Path
from invokeai.backend.util import GIG, directory_size from invokeai.backend.util import GIG, directory_size
from invokeai.backend.util.logging import InvokeAILogger
from .convert_cache_base import ModelConvertCacheBase from .convert_cache_base import ModelConvertCacheBase
class ModelConvertCache(ModelConvertCacheBase):
def __init__(self, cache_path: Path, max_size: float=10.0): class ModelConvertCache(ModelConvertCacheBase):
def __init__(self, cache_path: Path, max_size: float = 10.0):
"""Initialize the convert cache with the base directory and a limit on its maximum size (in GBs).""" """Initialize the convert cache with the base directory and a limit on its maximum size (in GBs)."""
if not cache_path.exists(): if not cache_path.exists():
cache_path.mkdir(parents=True) cache_path.mkdir(parents=True)

View File

@ -10,17 +10,19 @@ Use like this:
# do something with loaded_model # do something with loaded_model
""" """
import hashlib
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from logging import Logger from logging import Logger
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Dict, Optional, Type, Union from typing import Any, Callable, Dict, Optional, Tuple, Type
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase from invokeai.backend.model_manager.config import VaeCheckpointConfig, VaeDiffusersConfig
from invokeai.backend.model_manager.load.model_cache.model_locker import ModelLockerBase
from invokeai.backend.model_manager.load.convert_cache.convert_cache_base import ModelConvertCacheBase from invokeai.backend.model_manager.load.convert_cache.convert_cache_base import ModelConvertCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
@dataclass @dataclass
class LoadedModel: class LoadedModel:
@ -52,7 +54,7 @@ class ModelLoaderBase(ABC):
self, self,
app_config: InvokeAIAppConfig, app_config: InvokeAIAppConfig,
logger: Logger, logger: Logger,
ram_cache: ModelCacheBase, ram_cache: ModelCacheBase[AnyModel],
convert_cache: ModelConvertCacheBase, convert_cache: ModelConvertCacheBase,
): ):
"""Initialize the loader.""" """Initialize the loader."""
@ -91,7 +93,7 @@ class AnyModelLoader:
self, self,
app_config: InvokeAIAppConfig, app_config: InvokeAIAppConfig,
logger: Logger, logger: Logger,
ram_cache: ModelCacheBase, ram_cache: ModelCacheBase[AnyModel],
convert_cache: ModelConvertCacheBase, convert_cache: ModelConvertCacheBase,
): ):
"""Initialize AnyModelLoader with its dependencies.""" """Initialize AnyModelLoader with its dependencies."""
@ -101,11 +103,11 @@ class AnyModelLoader:
self._convert_cache = convert_cache self._convert_cache = convert_cache
@property @property
def ram_cache(self) -> ModelCacheBase: def ram_cache(self) -> ModelCacheBase[AnyModel]:
"""Return the RAM cache associated used by the loaders.""" """Return the RAM cache associated used by the loaders."""
return self._ram_cache return self._ram_cache
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType]=None) -> LoadedModel: def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
""" """
Return a model given its configuration. Return a model given its configuration.
@ -113,9 +115,7 @@ class AnyModelLoader:
:param submodel_type: an ModelType enum indicating the portion of :param submodel_type: an ModelType enum indicating the portion of
the model to retrieve (e.g. ModelType.Vae) the model to retrieve (e.g. ModelType.Vae)
""" """
implementation = self.__class__.get_implementation( implementation, model_config, submodel_type = self.__class__.get_implementation(model_config, submodel_type)
base=model_config.base, type=model_config.type, format=model_config.format
)
return implementation( return implementation(
app_config=self._app_config, app_config=self._app_config,
logger=self._logger, logger=self._logger,
@ -128,16 +128,37 @@ class AnyModelLoader:
return "-".join([base.value, type.value, format.value]) return "-".join([base.value, type.value, format.value])
@classmethod @classmethod
def get_implementation(cls, base: BaseModelType, type: ModelType, format: ModelFormat) -> Type[ModelLoaderBase]: def get_implementation(
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
) -> Tuple[Type[ModelLoaderBase], AnyModelConfig, Optional[SubModelType]]:
"""Get subclass of ModelLoaderBase registered to handle base and type.""" """Get subclass of ModelLoaderBase registered to handle base and type."""
key1 = cls._to_registry_key(base, type, format) # for a specific base type # We have to handle VAE overrides here because this will change the model type and the corresponding implementation returned
key2 = cls._to_registry_key(BaseModelType.Any, type, format) # with wildcard Any conf2, submodel_type = cls._handle_subtype_overrides(config, submodel_type)
key1 = cls._to_registry_key(conf2.base, conf2.type, conf2.format) # for a specific base type
key2 = cls._to_registry_key(BaseModelType.Any, conf2.type, conf2.format) # with wildcard Any
implementation = cls._registry.get(key1) or cls._registry.get(key2) implementation = cls._registry.get(key1) or cls._registry.get(key2)
if not implementation: if not implementation:
raise NotImplementedError( raise NotImplementedError(
f"No subclass of LoadedModel is registered for base={base}, type={type}, format={format}" f"No subclass of LoadedModel is registered for base={config.base}, type={config.type}, format={config.format}"
) )
return implementation return implementation, conf2, submodel_type
@classmethod
def _handle_subtype_overrides(
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
) -> Tuple[AnyModelConfig, Optional[SubModelType]]:
if submodel_type == SubModelType.Vae and hasattr(config, "vae") and config.vae is not None:
model_path = Path(config.vae)
config_class = (
VaeCheckpointConfig if model_path.suffix in [".pt", ".safetensors", ".ckpt"] else VaeDiffusersConfig
)
hash = hashlib.md5(model_path.as_posix().encode("utf-8")).hexdigest()
new_conf = config_class(path=model_path.as_posix(), name=model_path.stem, base=config.base, key=hash)
submodel_type = None
else:
new_conf = config
return new_conf, submodel_type
@classmethod @classmethod
def register( def register(
@ -152,4 +173,3 @@ class AnyModelLoader:
return subclass return subclass
return decorator return decorator

View File

@ -10,12 +10,12 @@ from diffusers import ModelMixin
from diffusers.configuration_utils import ConfigMixin from diffusers.configuration_utils import ConfigMixin
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager import AnyModelConfig, InvalidModelConfigException, ModelRepoVariant, SubModelType from invokeai.backend.model_manager import AnyModel, AnyModelConfig, InvalidModelConfigException, ModelRepoVariant, SubModelType
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.load_base import AnyModel, LoadedModel, ModelLoaderBase from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs, calc_model_size_by_data
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.util.devices import choose_torch_device, torch_dtype from invokeai.backend.util.devices import choose_torch_device, torch_dtype
@ -38,7 +38,7 @@ class ModelLoader(ModelLoaderBase):
self, self,
app_config: InvokeAIAppConfig, app_config: InvokeAIAppConfig,
logger: Logger, logger: Logger,
ram_cache: ModelCacheBase, ram_cache: ModelCacheBase[AnyModel],
convert_cache: ModelConvertCacheBase, convert_cache: ModelConvertCacheBase,
): ):
"""Initialize the loader.""" """Initialize the loader."""
@ -47,7 +47,6 @@ class ModelLoader(ModelLoaderBase):
self._ram_cache = ram_cache self._ram_cache = ram_cache
self._convert_cache = convert_cache self._convert_cache = convert_cache
self._torch_dtype = torch_dtype(choose_torch_device()) self._torch_dtype = torch_dtype(choose_torch_device())
self._size: Optional[int] = None # model size
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel: def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
""" """
@ -63,9 +62,7 @@ class ModelLoader(ModelLoaderBase):
if model_config.type == "main" and not submodel_type: if model_config.type == "main" and not submodel_type:
raise InvalidModelConfigException("submodel_type is required when loading a main model") raise InvalidModelConfigException("submodel_type is required when loading a main model")
model_path, is_submodel_override = self._get_model_path(model_config, submodel_type) model_path, model_config, submodel_type = self._get_model_path(model_config, submodel_type)
if is_submodel_override:
submodel_type = None
if not model_path.exists(): if not model_path.exists():
raise InvalidModelConfigException(f"Files for model 'model_config.name' not found at {model_path}") raise InvalidModelConfigException(f"Files for model 'model_config.name' not found at {model_path}")
@ -74,13 +71,12 @@ class ModelLoader(ModelLoaderBase):
locker = self._load_if_needed(model_config, model_path, submodel_type) locker = self._load_if_needed(model_config, model_path, submodel_type)
return LoadedModel(config=model_config, locker=locker) return LoadedModel(config=model_config, locker=locker)
# IMPORTANT: This needs to be overridden in the StableDiffusion subclass so as to handle vae overrides
# and submodels!!!!
def _get_model_path( def _get_model_path(
self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None
) -> Tuple[Path, bool]: ) -> Tuple[Path, AnyModelConfig, Optional[SubModelType]]:
model_base = self._app_config.models_path model_base = self._app_config.models_path
return ((model_base / config.path).resolve(), False) result = (model_base / config.path).resolve(), config, submodel_type
return result
def _convert_if_needed( def _convert_if_needed(
self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
@ -90,7 +86,7 @@ class ModelLoader(ModelLoaderBase):
if not self._needs_conversion(config, model_path, cache_path): if not self._needs_conversion(config, model_path, cache_path):
return cache_path if cache_path.exists() else model_path return cache_path if cache_path.exists() else model_path
self._convert_cache.make_room(self._size or self.get_size_fs(config, model_path, submodel_type)) self._convert_cache.make_room(self.get_size_fs(config, model_path, submodel_type))
return self._convert_model(config, model_path, cache_path) return self._convert_model(config, model_path, cache_path)
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, cache_path: Path) -> bool: def _needs_conversion(self, config: AnyModelConfig, model_path: Path, cache_path: Path) -> bool:
@ -114,6 +110,7 @@ class ModelLoader(ModelLoaderBase):
config.key, config.key,
submodel_type=submodel_type, submodel_type=submodel_type,
model=loaded_model, model=loaded_model,
size=calc_model_size_by_data(loaded_model),
) )
return self._ram_cache.get(config.key, submodel_type) return self._ram_cache.get(config.key, submodel_type)
@ -128,17 +125,6 @@ class ModelLoader(ModelLoaderBase):
variant=config.repo_variant if hasattr(config, "repo_variant") else None, variant=config.repo_variant if hasattr(config, "repo_variant") else None,
) )
def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
raise NotImplementedError
def _load_model(
self,
model_path: Path,
model_variant: Optional[ModelRepoVariant] = None,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
raise NotImplementedError
def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> Dict[str, Any]: def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> Dict[str, Any]:
return ConfigLoader.load_config(model_path, config_name=config_name) return ConfigLoader.load_config(model_path, config_name=config_name)
@ -161,3 +147,17 @@ class ModelLoader(ModelLoaderBase):
config = self._load_diffusers_config(model_path, config_name="config.json") config = self._load_diffusers_config(model_path, config_name="config.json")
class_name = config["_class_name"] class_name = config["_class_name"]
return self._hf_definition_to_type(module="diffusers", class_name=class_name) return self._hf_definition_to_type(module="diffusers", class_name=class_name)
# This needs to be implemented in subclasses that handle checkpoints
def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
raise NotImplementedError
# This needs to be implemented in the subclass
def _load_model(
self,
model_path: Path,
model_variant: Optional[ModelRepoVariant] = None,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
raise NotImplementedError

View File

@ -97,4 +97,4 @@ def get_pretty_snapshot_diff(snapshot_1: Optional[MemorySnapshot], snapshot_2: O
if snapshot_1.vram is not None and snapshot_2.vram is not None: if snapshot_1.vram is not None and snapshot_2.vram is not None:
msg += get_msg_line("VRAM", snapshot_1.vram, snapshot_2.vram) msg += get_msg_line("VRAM", snapshot_1.vram, snapshot_2.vram)
return msg return "\n"+msg if len(msg)>0 else msg

View File

@ -2,4 +2,4 @@
from .model_cache_base import ModelCacheBase from .model_cache_base import ModelCacheBase
from .model_cache_default import ModelCache from .model_cache_default import ModelCache
_all__ = ['ModelCacheBase', 'ModelCache'] _all__ = ["ModelCacheBase", "ModelCache"]

View File

@ -8,14 +8,15 @@ model will be cleared and (re)loaded from disk when next needed.
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, field from dataclasses import dataclass
from logging import Logger from logging import Logger
from typing import Dict, Optional, TypeVar, Generic from typing import Generic, Optional, TypeVar
import torch import torch
from invokeai.backend.model_manager import AnyModel, SubModelType from invokeai.backend.model_manager import AnyModel, SubModelType
class ModelLockerBase(ABC): class ModelLockerBase(ABC):
"""Base class for the model locker used by the loader.""" """Base class for the model locker used by the loader."""
@ -35,8 +36,10 @@ class ModelLockerBase(ABC):
"""Return the model.""" """Return the model."""
pass pass
T = TypeVar("T") T = TypeVar("T")
@dataclass @dataclass
class CacheRecord(Generic[T]): class CacheRecord(Generic[T]):
"""Elements of the cache.""" """Elements of the cache."""
@ -115,6 +118,7 @@ class ModelCacheBase(ABC, Generic[T]):
self, self,
key: str, key: str,
model: T, model: T,
size: int,
submodel_type: Optional[SubModelType] = None, submodel_type: Optional[SubModelType] = None,
) -> None: ) -> None:
"""Store model under key and optional submodel_type.""" """Store model under key and optional submodel_type."""

View File

@ -19,22 +19,24 @@ context. Use like this:
""" """
import gc import gc
import logging
import math import math
import sys
import time import time
from contextlib import suppress from contextlib import suppress
from logging import Logger from logging import Logger
from typing import Any, Dict, List, Optional from typing import Dict, List, Optional
import torch import torch
from invokeai.backend.model_manager import SubModelType from invokeai.backend.model_manager import SubModelType
from invokeai.backend.model_manager.load.load_base import AnyModel from invokeai.backend.model_manager.load.load_base import AnyModel
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
from invokeai.backend.util.devices import choose_torch_device from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
from .model_cache_base import CacheRecord, ModelCacheBase from .model_cache_base import CacheRecord, ModelCacheBase
from .model_locker import ModelLockerBase, ModelLocker from .model_locker import ModelLocker, ModelLockerBase
if choose_torch_device() == torch.device("mps"): if choose_torch_device() == torch.device("mps"):
from torch import mps from torch import mps
@ -91,7 +93,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
self._execution_device: torch.device = execution_device self._execution_device: torch.device = execution_device
self._storage_device: torch.device = storage_device self._storage_device: torch.device = storage_device
self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__) self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
self._log_memory_usage = log_memory_usage self._log_memory_usage = log_memory_usage or self._logger.level == logging.DEBUG
self._cached_models: Dict[str, CacheRecord[AnyModel]] = {} self._cached_models: Dict[str, CacheRecord[AnyModel]] = {}
self._cache_stack: List[str] = [] self._cache_stack: List[str] = []
@ -141,14 +143,14 @@ class ModelCache(ModelCacheBase[AnyModel]):
self, self,
key: str, key: str,
model: AnyModel, model: AnyModel,
size: int,
submodel_type: Optional[SubModelType] = None, submodel_type: Optional[SubModelType] = None,
) -> None: ) -> None:
"""Store model under key and optional submodel_type.""" """Store model under key and optional submodel_type."""
key = self._make_cache_key(key, submodel_type) key = self._make_cache_key(key, submodel_type)
assert key not in self._cached_models assert key not in self._cached_models
loaded_model_size = calc_model_size_by_data(model) cache_record = CacheRecord(key, model, size)
cache_record = CacheRecord(key, model, loaded_model_size)
self._cached_models[key] = cache_record self._cached_models[key] = cache_record
self._cache_stack.append(key) self._cache_stack.append(key)
@ -195,28 +197,32 @@ class ModelCache(ModelCacheBase[AnyModel]):
for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size): for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
if vram_in_use <= reserved: if vram_in_use <= reserved:
break break
if not cache_entry.loaded:
continue
if not cache_entry.locked: if not cache_entry.locked:
self.move_model_to_device(cache_entry, self.storage_device) self.move_model_to_device(cache_entry, self.storage_device)
cache_entry.loaded = False cache_entry.loaded = False
vram_in_use = torch.cuda.memory_allocated() + size_required vram_in_use = torch.cuda.memory_allocated() + size_required
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM now available for models; max allowed={(reserved/GIG):.2f}GB") self.logger.debug(
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
)
torch.cuda.empty_cache() torch.cuda.empty_cache()
if choose_torch_device() == torch.device("mps"): if choose_torch_device() == torch.device("mps"):
mps.empty_cache() mps.empty_cache()
# TO DO: Only reason to pass the CacheRecord rather than the model is to get the key and size def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
# for printing debugging messages. Revisit whether this is necessary
def move_model_to_device(self, cache_entry: CacheRecord, target_device: torch.device) -> None:
"""Move model into the indicated device.""" """Move model into the indicated device."""
# These attributes are not in the base class but in derived classes # These attributes are not in the base ModelMixin class but in derived classes.
assert hasattr(cache_entry.model, "device") # Some models don't have these attributes, in which case they run in RAM/CPU.
assert hasattr(cache_entry.model, "to") self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
if not (hasattr(cache_entry.model, "device") and hasattr(cache_entry.model, "to")):
return
source_device = cache_entry.model.device source_device = cache_entry.model.device
# Note: We compare device types only so that 'cuda' == 'cuda:0'. This would need to be revised to support # Note: We compare device types only so that 'cuda' == 'cuda:0'.
# multi-GPU. # This would need to be revised to support multi-GPU.
if torch.device(source_device).type == torch.device(target_device).type: if torch.device(source_device).type == torch.device(target_device).type:
return return
@ -227,8 +233,8 @@ class ModelCache(ModelCacheBase[AnyModel]):
end_model_to_time = time.time() end_model_to_time = time.time()
self.logger.debug( self.logger.debug(
f"Moved model '{cache_entry.key}' from {source_device} to" f"Moved model '{cache_entry.key}' from {source_device} to"
f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s.\n" f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
f"Estimated model size: {(cache_entry.size/GIG):.3f} GB.\n" f"Estimated model size: {(cache_entry.size/GIG):.3f} GB."
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}" f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
) )
@ -291,7 +297,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
f" {(bytes_needed/GIG):.2f} GB" f" {(bytes_needed/GIG):.2f} GB"
) )
self.logger.debug(f"Before unloading: cached_models={len(self._cached_models)}") self.logger.debug(f"Before making_room: cached_models={len(self._cached_models)}")
pos = 0 pos = 0
models_cleared = 0 models_cleared = 0
@ -336,7 +342,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
# 1 from onnx runtime object # 1 from onnx runtime object
if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2): if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2):
self.logger.debug( self.logger.debug(
f"Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)" f"Removing {model_key} from RAM cache to free at least {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
) )
current_size -= cache_entry.size current_size -= cache_entry.size
models_cleared += 1 models_cleared += 1
@ -365,4 +371,4 @@ class ModelCache(ModelCacheBase[AnyModel]):
if choose_torch_device() == torch.device("mps"): if choose_torch_device() == torch.device("mps"):
mps.empty_cache() mps.empty_cache()
self.logger.debug(f"After unloading: cached_models={len(self._cached_models)}") self.logger.debug(f"After making room: cached_models={len(self._cached_models)}")

View File

@ -2,9 +2,10 @@
Base class and implementation of a class that moves models in and out of VRAM. Base class and implementation of a class that moves models in and out of VRAM.
""" """
from abc import ABC, abstractmethod
from invokeai.backend.model_manager import AnyModel from invokeai.backend.model_manager import AnyModel
from .model_cache_base import ModelLockerBase, ModelCacheBase, CacheRecord
from .model_cache_base import CacheRecord, ModelCacheBase, ModelLockerBase
class ModelLocker(ModelLockerBase): class ModelLocker(ModelLockerBase):
"""Internal class that mediates movement in and out of GPU.""" """Internal class that mediates movement in and out of GPU."""
@ -56,4 +57,3 @@ class ModelLocker(ModelLockerBase):
if not self._cache.lazy_offloading: if not self._cache.lazy_offloading:
self._cache.offload_unlocked_models(self._cache_entry.size) self._cache.offload_unlocked_models(self._cache_entry.size)
self._cache.print_cuda_stats() self._cache.print_cuda_stats()

View File

@ -0,0 +1,60 @@
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
"""Class for ControlNet model loading in InvokeAI."""
from pathlib import Path
import safetensors
import torch
from invokeai.backend.model_manager import (
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_controlnet_to_diffusers
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
from .generic_diffusers import GenericDiffusersLoader
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Diffusers)
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Checkpoint)
class ControlnetLoader(GenericDiffusersLoader):
"""Class to load ControlNet models."""
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
if config.format != ModelFormat.Checkpoint:
return False
elif (
dest_path.exists()
and (dest_path / "config.json").stat().st_mtime >= (config.last_modified or 0.0)
and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime
):
return False
else:
return True
def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
raise Exception(f"Vae conversion not supported for model type: {config.base}")
else:
assert hasattr(config, 'config')
config_file = config.config
if weights_path.suffix == ".safetensors":
checkpoint = safetensors.torch.load_file(weights_path, device="cpu")
else:
checkpoint = torch.load(weights_path, map_location="cpu")
# sometimes weights are hidden under "state_dict", and sometimes not
if "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"]
convert_controlnet_to_diffusers(
weights_path,
output_path,
original_config_file=self._app_config.root_path / config_file,
image_size=512,
scan_needed=True,
from_safetensors=weights_path.suffix == ".safetensors",
)
return output_path

View File

@ -0,0 +1,34 @@
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
"""Class for simple diffusers model loading in InvokeAI."""
from pathlib import Path
from typing import Optional
from invokeai.backend.model_manager import (
AnyModel,
BaseModelType,
ModelFormat,
ModelRepoVariant,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
from invokeai.backend.model_manager.load.load_default import ModelLoader
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers)
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.T2IAdapter, format=ModelFormat.Diffusers)
class GenericDiffusersLoader(ModelLoader):
"""Class to load simple diffusers models."""
def _load_model(
self,
model_path: Path,
model_variant: Optional[ModelRepoVariant] = None,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
model_class = self._get_hf_load_class(model_path)
if submodel_type is not None:
raise Exception(f"There are no submodels in models of type {model_class}")
variant = model_variant.value if model_variant else None
result: AnyModel = model_class.from_pretrained(model_path, torch_dtype=self._torch_dtype, variant=variant) # type: ignore
return result

View File

@ -0,0 +1,39 @@
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
"""Class for IP Adapter model loading in InvokeAI."""
import torch
from pathlib import Path
from typing import Optional
from invokeai.backend.ip_adapter.ip_adapter import build_ip_adapter
from invokeai.backend.model_manager import (
AnyModel,
BaseModelType,
ModelFormat,
ModelRepoVariant,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
from invokeai.backend.model_manager.load.load_default import ModelLoader
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.InvokeAI)
class IPAdapterInvokeAILoader(ModelLoader):
"""Class to load IP Adapter diffusers models."""
def _load_model(
self,
model_path: Path,
model_variant: Optional[ModelRepoVariant] = None,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if submodel_type is not None:
raise ValueError("There are no submodels in an IP-Adapter model.")
model = build_ip_adapter(
ip_adapter_ckpt_path=model_path / "ip_adapter.bin",
device=torch.device("cpu"),
dtype=self._torch_dtype,
)
return model

View File

@ -0,0 +1,76 @@
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
"""Class for LoRA model loading in InvokeAI."""
from pathlib import Path
from typing import Optional, Tuple
from logging import Logger
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelRepoVariant,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.lora import LoRAModelRaw
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
from invokeai.backend.model_manager.load.load_default import ModelLoader
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Diffusers)
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Lycoris)
class LoraLoader(ModelLoader):
"""Class to load LoRA models."""
# We cheat a little bit to get access to the model base
def __init__(
self,
app_config: InvokeAIAppConfig,
logger: Logger,
ram_cache: ModelCacheBase[AnyModel],
convert_cache: ModelConvertCacheBase,
):
"""Initialize the loader."""
super().__init__(app_config, logger, ram_cache, convert_cache)
self._model_base: Optional[BaseModelType] = None
def _load_model(
self,
model_path: Path,
model_variant: Optional[ModelRepoVariant] = None,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if submodel_type is not None:
raise ValueError("There are no submodels in a LoRA model.")
model = LoRAModelRaw.from_checkpoint(
file_path=model_path,
dtype=self._torch_dtype,
base_model=self._model_base,
)
return model
# override
def _get_model_path(
self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None
) -> Tuple[Path, AnyModelConfig, Optional[SubModelType]]:
self._model_base = config.base # cheating a little - setting this variable for later call to _load_model()
model_base_path = self._app_config.models_path
model_path = model_base_path / config.path
if config.format == ModelFormat.Diffusers:
for ext in ["safetensors", "bin"]: # return path to the safetensors file inside the folder
path = model_base_path / config.path / f"pytorch_lora_weights.{ext}"
if path.exists():
model_path = path
break
result = model_path.resolve(), config, submodel_type
return result

View File

@ -0,0 +1,93 @@
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
"""Class for StableDiffusion model loading in InvokeAI."""
from pathlib import Path
from typing import Optional
from diffusers import StableDiffusionInpaintPipeline, StableDiffusionPipeline
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelRepoVariant,
ModelType,
ModelVariantType,
SubModelType,
)
from invokeai.backend.model_manager.config import MainCheckpointConfig
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
from invokeai.backend.model_manager.load.load_default import ModelLoader
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Diffusers)
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Checkpoint)
class StableDiffusionDiffusersModel(ModelLoader):
"""Class to load main models."""
model_base_to_model_type = {
BaseModelType.StableDiffusion1: "FrozenCLIPEmbedder",
BaseModelType.StableDiffusion2: "FrozenOpenCLIPEmbedder",
BaseModelType.StableDiffusionXL: "SDXL",
BaseModelType.StableDiffusionXLRefiner: "SDXL-Refiner",
}
def _load_model(
self,
model_path: Path,
model_variant: Optional[ModelRepoVariant] = None,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not submodel_type is not None:
raise Exception("A submodel type must be provided when loading main pipelines.")
load_class = self._get_hf_load_class(model_path, submodel_type)
variant = model_variant.value if model_variant else None
model_path = model_path / submodel_type.value
result: AnyModel = load_class.from_pretrained(
model_path,
torch_dtype=self._torch_dtype,
variant=variant,
) # type: ignore
return result
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
if config.format != ModelFormat.Checkpoint:
return False
elif (
dest_path.exists()
and (dest_path / "model_index.json").stat().st_mtime >= (config.last_modified or 0.0)
and (dest_path / "model_index.json").stat().st_mtime >= model_path.stat().st_mtime
):
return False
else:
return True
def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
assert isinstance(config, MainCheckpointConfig)
variant = config.variant
base = config.base
pipeline_class = (
StableDiffusionInpaintPipeline if variant == ModelVariantType.Inpaint else StableDiffusionPipeline
)
config_file = config.config
self._logger.info(f"Converting {weights_path} to diffusers format")
convert_ckpt_to_diffusers(
weights_path,
output_path,
model_type=self.model_base_to_model_type[base],
model_version=base,
model_variant=variant,
original_config_file=self._app_config.root_path / config_file,
extract_ema=True,
scan_needed=True,
pipeline_class=pipeline_class,
from_safetensors=weights_path.suffix == ".safetensors",
precision=self._torch_dtype,
load_safety_checker=False,
)
return output_path

View File

@ -2,68 +2,54 @@
"""Class for VAE model loading in InvokeAI.""" """Class for VAE model loading in InvokeAI."""
from pathlib import Path from pathlib import Path
from typing import Optional
import torch
import safetensors import safetensors
from omegaconf import OmegaConf, DictConfig import torch
from invokeai.backend.util.devices import torch_dtype from omegaconf import DictConfig, OmegaConf
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelFormat, ModelRepoVariant, ModelType, SubModelType
from invokeai.backend.model_manager.load.load_base import AnyModelLoader from invokeai.backend.model_manager import (
from invokeai.backend.model_manager.load.load_default import ModelLoader AnyModelConfig,
BaseModelType,
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
from .generic_diffusers import GenericDiffusersLoader
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers) @AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers)
@AnyModelLoader.register(base=BaseModelType.StableDiffusion1, type=ModelType.Vae, format=ModelFormat.Checkpoint) @AnyModelLoader.register(base=BaseModelType.StableDiffusion1, type=ModelType.Vae, format=ModelFormat.Checkpoint)
@AnyModelLoader.register(base=BaseModelType.StableDiffusion2, type=ModelType.Vae, format=ModelFormat.Checkpoint) @AnyModelLoader.register(base=BaseModelType.StableDiffusion2, type=ModelType.Vae, format=ModelFormat.Checkpoint)
class VaeDiffusersModel(ModelLoader): class VaeLoader(GenericDiffusersLoader):
"""Class to load VAE models.""" """Class to load VAE models."""
def _load_model(
self,
model_path: Path,
model_variant: Optional[ModelRepoVariant] = None,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if submodel_type is not None:
raise Exception("There are no submodels in VAEs")
vae_class = self._get_hf_load_class(model_path)
variant = model_variant.value if model_variant else None
result: AnyModel = vae_class.from_pretrained(
model_path, torch_dtype=self._torch_dtype, variant=variant
) # type: ignore
return result
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool: def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
print(f'DEBUG: last_modified={config.last_modified}')
print(f'DEBUG: cache_path={(dest_path / "config.json").stat().st_mtime}')
print(f'DEBUG: model_path={model_path.stat().st_mtime}')
if config.format != ModelFormat.Checkpoint: if config.format != ModelFormat.Checkpoint:
return False return False
elif dest_path.exists() \ elif (
and (dest_path / "config.json").stat().st_mtime >= config.last_modified \ dest_path.exists()
and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime: and (dest_path / "config.json").stat().st_mtime >= (config.last_modified or 0.0)
and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime
):
return False return False
else: else:
return True return True
def _convert_model(self, def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
config: AnyModelConfig, # TO DO: check whether sdxl VAE models convert.
weights_path: Path,
output_path: Path
) -> Path:
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}: if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
raise Exception(f"Vae conversion not supported for model type: {config.base}") raise Exception(f"Vae conversion not supported for model type: {config.base}")
else: else:
config_file = 'v1-inference.yaml' if config.base == BaseModelType.StableDiffusion1 else "v2-inference-v.yaml" config_file = (
"v1-inference.yaml" if config.base == BaseModelType.StableDiffusion1 else "v2-inference-v.yaml"
)
if weights_path.suffix == ".safetensors": if weights_path.suffix == ".safetensors":
checkpoint = safetensors.torch.load_file(weights_path, device="cpu") checkpoint = safetensors.torch.load_file(weights_path, device="cpu")
else: else:
checkpoint = torch.load(weights_path, map_location="cpu") checkpoint = torch.load(weights_path, map_location="cpu")
dtype = torch_dtype()
# sometimes weights are hidden under "state_dict", and sometimes not # sometimes weights are hidden under "state_dict", and sometimes not
if "state_dict" in checkpoint: if "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"] checkpoint = checkpoint["state_dict"]
@ -71,13 +57,11 @@ class VaeDiffusersModel(ModelLoader):
ckpt_config = OmegaConf.load(self._app_config.legacy_conf_path / config_file) ckpt_config = OmegaConf.load(self._app_config.legacy_conf_path / config_file)
assert isinstance(ckpt_config, DictConfig) assert isinstance(ckpt_config, DictConfig)
print(f'DEBUG: CONVERTIGN')
vae_model = convert_ldm_vae_to_diffusers( vae_model = convert_ldm_vae_to_diffusers(
checkpoint=checkpoint, checkpoint=checkpoint,
vae_config=ckpt_config, vae_config=ckpt_config,
image_size=512, image_size=512,
) )
vae_model.to(dtype) # set precision appropriately vae_model.to(self._torch_dtype) # set precision appropriately
vae_model.save_pretrained(output_path, safe_serialization=True, torch_dtype=dtype) vae_model.save_pretrained(output_path, safe_serialization=True)
return output_path return output_path

View File

@ -8,10 +8,11 @@ from typing import Optional, Union
import torch import torch
from diffusers import DiffusionPipeline from diffusers import DiffusionPipeline
from invokeai.backend.model_manager.config import AnyModel
from invokeai.backend.model_manager.onnx_runtime import IAIOnnxRuntimeModel from invokeai.backend.model_manager.onnx_runtime import IAIOnnxRuntimeModel
def calc_model_size_by_data(model: Union[DiffusionPipeline, torch.nn.Module, IAIOnnxRuntimeModel]) -> int: def calc_model_size_by_data(model: AnyModel) -> int:
"""Get size of a model in memory in bytes.""" """Get size of a model in memory in bytes."""
if isinstance(model, DiffusionPipeline): if isinstance(model, DiffusionPipeline):
return _calc_pipeline_by_data(model) return _calc_pipeline_by_data(model)

View File

@ -0,0 +1,620 @@
# Copyright (c) 2024 The InvokeAI Development team
"""LoRA model support."""
import torch
from safetensors.torch import load_file
from pathlib import Path
from typing import Dict, Optional, Union, List, Tuple
from typing_extensions import Self
from invokeai.backend.model_manager import BaseModelType
class LoRALayerBase:
# rank: Optional[int]
# alpha: Optional[float]
# bias: Optional[torch.Tensor]
# layer_key: str
# @property
# def scale(self):
# return self.alpha / self.rank if (self.alpha and self.rank) else 1.0
def __init__(
self,
layer_key: str,
values: Dict[str, torch.Tensor],
):
if "alpha" in values:
self.alpha = values["alpha"].item()
else:
self.alpha = None
if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
self.bias: Optional[torch.Tensor] = torch.sparse_coo_tensor(
values["bias_indices"],
values["bias_values"],
tuple(values["bias_size"]),
)
else:
self.bias = None
self.rank = None # set in layer implementation
self.layer_key = layer_key
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
raise NotImplementedError()
def calc_size(self) -> int:
model_size = 0
for val in [self.bias]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> None:
if self.bias is not None:
self.bias = self.bias.to(device=device, dtype=dtype)
# TODO: find and debug lora/locon with bias
class LoRALayer(LoRALayerBase):
# up: torch.Tensor
# mid: Optional[torch.Tensor]
# down: torch.Tensor
def __init__(
self,
layer_key: str,
values: Dict[str, torch.Tensor],
):
super().__init__(layer_key, values)
self.up = values["lora_up.weight"]
self.down = values["lora_down.weight"]
if "lora_mid.weight" in values:
self.mid: Optional[torch.Tensor] = values["lora_mid.weight"]
else:
self.mid = None
self.rank = self.down.shape[0]
def get_weight(self, orig_weight: torch.Tensor):
if self.mid is not None:
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
else:
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.up, self.mid, self.down]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> None:
super().to(device=device, dtype=dtype)
self.up = self.up.to(device=device, dtype=dtype)
self.down = self.down.to(device=device, dtype=dtype)
if self.mid is not None:
self.mid = self.mid.to(device=device, dtype=dtype)
class LoHALayer(LoRALayerBase):
# w1_a: torch.Tensor
# w1_b: torch.Tensor
# w2_a: torch.Tensor
# w2_b: torch.Tensor
# t1: Optional[torch.Tensor] = None
# t2: Optional[torch.Tensor] = None
def __init__(
self,
layer_key: str,
values: Dict[str, torch.Tensor]
):
super().__init__(layer_key, values)
self.w1_a = values["hada_w1_a"]
self.w1_b = values["hada_w1_b"]
self.w2_a = values["hada_w2_a"]
self.w2_b = values["hada_w2_b"]
if "hada_t1" in values:
self.t1: Optional[torch.Tensor] = values["hada_t1"]
else:
self.t1 = None
if "hada_t2" in values:
self.t2: Optional[torch.Tensor] = values["hada_t2"]
else:
self.t2 = None
self.rank = self.w1_b.shape[0]
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
if self.t1 is None:
weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
else:
rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a)
weight = rebuild1 * rebuild2
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> None:
super().to(device=device, dtype=dtype)
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
if self.t1 is not None:
self.t1 = self.t1.to(device=device, dtype=dtype)
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype)
class LoKRLayer(LoRALayerBase):
# w1: Optional[torch.Tensor] = None
# w1_a: Optional[torch.Tensor] = None
# w1_b: Optional[torch.Tensor] = None
# w2: Optional[torch.Tensor] = None
# w2_a: Optional[torch.Tensor] = None
# w2_b: Optional[torch.Tensor] = None
# t2: Optional[torch.Tensor] = None
def __init__(
self,
layer_key: str,
values: Dict[str, torch.Tensor],
):
super().__init__(layer_key, values)
if "lokr_w1" in values:
self.w1: Optional[torch.Tensor] = values["lokr_w1"]
self.w1_a = None
self.w1_b = None
else:
self.w1 = None
self.w1_a = values["lokr_w1_a"]
self.w1_b = values["lokr_w1_b"]
if "lokr_w2" in values:
self.w2: Optional[torch.Tensor] = values["lokr_w2"]
self.w2_a = None
self.w2_b = None
else:
self.w2 = None
self.w2_a = values["lokr_w2_a"]
self.w2_b = values["lokr_w2_b"]
if "lokr_t2" in values:
self.t2: Optional[torch.Tensor] = values["lokr_t2"]
else:
self.t2 = None
if "lokr_w1_b" in values:
self.rank = values["lokr_w1_b"].shape[0]
elif "lokr_w2_b" in values:
self.rank = values["lokr_w2_b"].shape[0]
else:
self.rank = None # unscaled
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
w1: Optional[torch.Tensor] = self.w1
if w1 is None:
assert self.w1_a is not None
assert self.w1_b is not None
w1 = self.w1_a @ self.w1_b
w2 = self.w2
if w2 is None:
if self.t2 is None:
assert self.w2_a is not None
assert self.w2_b is not None
w2 = self.w2_a @ self.w2_b
else:
w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b)
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
w2 = w2.contiguous()
assert w1 is not None
assert w2 is not None
weight = torch.kron(w1, w2)
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> None:
super().to(device=device, dtype=dtype)
if self.w1 is not None:
self.w1 = self.w1.to(device=device, dtype=dtype)
else:
assert self.w1_a is not None
assert self.w1_b is not None
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
if self.w2 is not None:
self.w2 = self.w2.to(device=device, dtype=dtype)
else:
assert self.w2_a is not None
assert self.w2_b is not None
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype)
class FullLayer(LoRALayerBase):
# weight: torch.Tensor
def __init__(
self,
layer_key: str,
values: Dict[str, torch.Tensor],
):
super().__init__(layer_key, values)
self.weight = values["diff"]
if len(values.keys()) > 1:
_keys = list(values.keys())
_keys.remove("diff")
raise NotImplementedError(f"Unexpected keys in lora diff layer: {_keys}")
self.rank = None # unscaled
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
return self.weight
def calc_size(self) -> int:
model_size = super().calc_size()
model_size += self.weight.nelement() * self.weight.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype)
class IA3Layer(LoRALayerBase):
# weight: torch.Tensor
# on_input: torch.Tensor
def __init__(
self,
layer_key: str,
values: Dict[str, torch.Tensor],
):
super().__init__(layer_key, values)
self.weight = values["weight"]
self.on_input = values["on_input"]
self.rank = None # unscaled
def get_weight(self, orig_weight: torch.Tensor):
weight = self.weight
if not self.on_input:
weight = weight.reshape(-1, 1)
return orig_weight * weight
def calc_size(self) -> int:
model_size = super().calc_size()
model_size += self.weight.nelement() * self.weight.element_size()
model_size += self.on_input.nelement() * self.on_input.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype)
self.on_input = self.on_input.to(device=device, dtype=dtype)
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
# TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix
class LoRAModelRaw: # (torch.nn.Module):
_name: str
layers: Dict[str, AnyLoRALayer]
def __init__(
self,
name: str,
layers: Dict[str, AnyLoRALayer],
):
self._name = name
self.layers = layers
@property
def name(self) -> str:
return self._name
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> None:
# TODO: try revert if exception?
for _key, layer in self.layers.items():
layer.to(device=device, dtype=dtype)
def calc_size(self) -> int:
model_size = 0
for _, layer in self.layers.items():
model_size += layer.calc_size()
return model_size
@classmethod
def _convert_sdxl_keys_to_diffusers_format(cls, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""Convert the keys of an SDXL LoRA state_dict to diffusers format.
The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
diffusers format, then this function will have no effect.
This function is adapted from:
https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409
Args:
state_dict (Dict[str, Tensor]): The SDXL LoRA state_dict.
Raises:
ValueError: If state_dict contains an unrecognized key, or not all keys could be converted.
Returns:
Dict[str, Tensor]: The diffusers-format state_dict.
"""
converted_count = 0 # The number of Stability AI keys converted to diffusers format.
not_converted_count = 0 # The number of keys that were not converted.
# Get a sorted list of Stability AI UNet keys so that we can efficiently search for keys with matching prefixes.
# For example, we want to efficiently find `input_blocks_4_1` in the list when searching for
# `input_blocks_4_1_proj_in`.
stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
stability_unet_keys.sort()
new_state_dict = {}
for full_key, value in state_dict.items():
if full_key.startswith("lora_unet_"):
search_key = full_key.replace("lora_unet_", "")
# Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix.
position = bisect.bisect_right(stability_unet_keys, search_key)
map_key = stability_unet_keys[position - 1]
# Now, check if the map_key *actually* matches the search_key.
if search_key.startswith(map_key):
new_key = full_key.replace(map_key, SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP[map_key])
new_state_dict[new_key] = value
converted_count += 1
else:
new_state_dict[full_key] = value
not_converted_count += 1
elif full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
# The CLIP text encoders have the same keys in both Stability AI and diffusers formats.
new_state_dict[full_key] = value
continue
else:
raise ValueError(f"Unrecognized SDXL LoRA key prefix: '{full_key}'.")
if converted_count > 0 and not_converted_count > 0:
raise ValueError(
f"The SDXL LoRA could only be partially converted to diffusers format. converted={converted_count},"
f" not_converted={not_converted_count}"
)
return new_state_dict
@classmethod
def from_checkpoint(
cls,
file_path: Union[str, Path],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
base_model: Optional[BaseModelType] = None,
) -> Self:
device = device or torch.device("cpu")
dtype = dtype or torch.float32
if isinstance(file_path, str):
file_path = Path(file_path)
model = cls(
name=file_path.stem, # TODO:
layers={},
)
if file_path.suffix == ".safetensors":
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
else:
state_dict = torch.load(file_path, map_location="cpu")
state_dict = cls._group_state(state_dict)
if base_model == BaseModelType.StableDiffusionXL:
state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)
for layer_key, values in state_dict.items():
# lora and locon
if "lora_down.weight" in values:
layer: AnyLoRALayer = LoRALayer(layer_key, values)
# loha
elif "hada_w1_b" in values:
layer = LoHALayer(layer_key, values)
# lokr
elif "lokr_w1_b" in values or "lokr_w1" in values:
layer = LoKRLayer(layer_key, values)
# diff
elif "diff" in values:
layer = FullLayer(layer_key, values)
# ia3
elif "weight" in values and "on_input" in values:
layer = IA3Layer(layer_key, values)
else:
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
raise Exception("Unknown lora format!")
# lower memory consumption by removing already parsed layer values
state_dict[layer_key].clear()
layer.to(device=device, dtype=dtype)
model.layers[layer_key] = layer
return model
@staticmethod
def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
state_dict_groupped: Dict[str, Dict[str, torch.Tensor]] = {}
for key, value in state_dict.items():
stem, leaf = key.split(".", 1)
if stem not in state_dict_groupped:
state_dict_groupped[stem] = {}
state_dict_groupped[stem][leaf] = value
return state_dict_groupped
# code from
# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
def make_sdxl_unet_conversion_map() -> List[Tuple[str,str]]:
"""Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format."""
unet_conversion_map_layer = []
for i in range(3): # num_blocks is 3 in sdxl
# loop over downblocks/upblocks
for j in range(2):
# loop over resnets/attentions for downblocks
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
if i < 3:
# no attention layers in down_blocks.3
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
for j in range(3):
# loop over resnets/attentions for upblocks
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
# if i > 0: commentout for sdxl
# no attention layers in up_blocks.0
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
if i < 3:
# no downsample in down_blocks.3
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
# no upsample in up_blocks.3
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
for j in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{j}."
sd_mid_res_prefix = f"middle_block.{2*j}."
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
unet_conversion_map_resnet = [
# (stable-diffusion, HF Diffusers)
("in_layers.0.", "norm1."),
("in_layers.2.", "conv1."),
("out_layers.0.", "norm2."),
("out_layers.3.", "conv2."),
("emb_layers.1.", "time_emb_proj."),
("skip_connection.", "conv_shortcut."),
]
unet_conversion_map = []
for sd, hf in unet_conversion_map_layer:
if "resnets" in hf:
for sd_res, hf_res in unet_conversion_map_resnet:
unet_conversion_map.append((sd + sd_res, hf + hf_res))
else:
unet_conversion_map.append((sd, hf))
for j in range(2):
hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
sd_time_embed_prefix = f"time_embed.{j*2}."
unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
for j in range(2):
hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
sd_label_embed_prefix = f"label_emb.0.{j*2}."
unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
unet_conversion_map.append(("out.0.", "conv_norm_out."))
unet_conversion_map.append(("out.2.", "conv_out."))
return unet_conversion_map
SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = {
sd.rstrip(".").replace(".", "_"): hf.rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()
}

View File

@ -29,8 +29,12 @@ CkptType = Dict[str, Any]
LEGACY_CONFIGS: Dict[BaseModelType, Dict[ModelVariantType, Union[str, Dict[SchedulerPredictionType, str]]]] = { LEGACY_CONFIGS: Dict[BaseModelType, Dict[ModelVariantType, Union[str, Dict[SchedulerPredictionType, str]]]] = {
BaseModelType.StableDiffusion1: { BaseModelType.StableDiffusion1: {
ModelVariantType.Normal: "v1-inference.yaml", ModelVariantType.Normal: {
SchedulerPredictionType.Epsilon: "v1-inference.yaml",
SchedulerPredictionType.VPrediction: "v1-inference-v.yaml",
},
ModelVariantType.Inpaint: "v1-inpainting-inference.yaml", ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
ModelVariantType.Depth: "v2-midas-inference.yaml",
}, },
BaseModelType.StableDiffusion2: { BaseModelType.StableDiffusion2: {
ModelVariantType.Normal: { ModelVariantType.Normal: {

View File

@ -20,6 +20,14 @@ from .util import ( # TO DO: Clean this up; remove the unused symbols
download_with_resume, download_with_resume,
instantiate_from_config, # noqa instantiate_from_config, # noqa
url_attachment_name, # noqa url_attachment_name, # noqa
) )
__all__ = ["GIG", "directory_size","Chdir", "download_with_resume", "InvokeAILogger", "choose_precision", "choose_torch_device"] __all__ = [
"GIG",
"directory_size",
"Chdir",
"download_with_resume",
"InvokeAILogger",
"choose_precision",
"choose_torch_device",
]

View File

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from contextlib import nullcontext from contextlib import nullcontext
from typing import Union, Optional from typing import Optional, Union
import torch import torch
from torch import autocast from torch import autocast

View File

@ -27,6 +27,7 @@ from .devices import torch_dtype
# actual size of a gig # actual size of a gig
GIG = 1073741824 GIG = 1073741824
def directory_size(directory: Path) -> int: def directory_size(directory: Path) -> int:
""" """
Return the aggregate size of all files in a directory (bytes). Return the aggregate size of all files in a directory (bytes).
@ -39,6 +40,7 @@ def directory_size(directory: Path) -> int:
sum += Path(root, d).stat().st_size sum += Path(root, d).stat().st_size
return sum return sum
def log_txt_as_img(wh, xc, size=10): def log_txt_as_img(wh, xc, size=10):
# wh a tuple of (width, height) # wh a tuple of (width, height)
# xc a list of captions to plot # xc a list of captions to plot