mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
loaders for main, controlnet, ip-adapter, clipvision and t2i
This commit is contained in:
parent
9804cb0e67
commit
420f6050a6
@ -4,10 +4,10 @@ from logging import Logger
|
|||||||
|
|
||||||
from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory
|
from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory
|
||||||
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
||||||
from invokeai.backend.model_manager.metadata import ModelMetadataStore
|
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
|
||||||
from invokeai.backend.model_manager.load import AnyModelLoader, ModelConvertCache
|
from invokeai.backend.model_manager.load import AnyModelLoader, ModelConvertCache
|
||||||
from invokeai.backend.model_manager.load.model_cache import ModelCache
|
from invokeai.backend.model_manager.load.model_cache import ModelCache
|
||||||
|
from invokeai.backend.model_manager.metadata import ModelMetadataStore
|
||||||
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
from invokeai.version.invokeai_version import __version__
|
from invokeai.version.invokeai_version import __version__
|
||||||
|
|
||||||
from ..services.board_image_records.board_image_records_sqlite import SqliteBoardImageRecordStorage
|
from ..services.board_image_records.board_image_records_sqlite import SqliteBoardImageRecordStorage
|
||||||
@ -90,15 +90,14 @@ class ApiDependencies:
|
|||||||
model_loader = AnyModelLoader(
|
model_loader = AnyModelLoader(
|
||||||
app_config=config,
|
app_config=config,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
ram_cache=ModelCache(max_cache_size=config.ram_cache_size,
|
ram_cache=ModelCache(
|
||||||
max_vram_cache_size=config.vram_cache_size,
|
max_cache_size=config.ram_cache_size, max_vram_cache_size=config.vram_cache_size, logger=logger
|
||||||
logger=logger),
|
),
|
||||||
convert_cache=ModelConvertCache(
|
convert_cache=ModelConvertCache(
|
||||||
cache_path = config.models_convert_cache_path,
|
cache_path=config.models_convert_cache_path, max_size=config.convert_cache_size
|
||||||
max_size = config.convert_cache_size
|
),
|
||||||
)
|
)
|
||||||
)
|
model_record_service = ModelRecordServiceSQL(db=db, loader=model_loader)
|
||||||
model_record_service = ModelRecordServiceSQL(db=db,loader=model_loader)
|
|
||||||
download_queue_service = DownloadQueueService(event_bus=events)
|
download_queue_service = DownloadQueueService(event_bus=events)
|
||||||
model_install_service = ModelInstallService(
|
model_install_service = ModelInstallService(
|
||||||
app_config=config,
|
app_config=config,
|
||||||
|
@ -173,7 +173,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, ClassVar, Dict, List, Literal, Optional, Union, get_type_hints
|
from typing import Any, ClassVar, Dict, List, Literal, Optional, get_type_hints
|
||||||
|
|
||||||
from omegaconf import DictConfig, OmegaConf
|
from omegaconf import DictConfig, OmegaConf
|
||||||
from pydantic import Field, TypeAdapter
|
from pydantic import Field, TypeAdapter
|
||||||
|
@ -11,7 +11,14 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||||
from invokeai.backend.model_manager import LoadedModel, AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
|
from invokeai.backend.model_manager import (
|
||||||
|
AnyModelConfig,
|
||||||
|
BaseModelType,
|
||||||
|
LoadedModel,
|
||||||
|
ModelFormat,
|
||||||
|
ModelType,
|
||||||
|
SubModelType,
|
||||||
|
)
|
||||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore
|
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore
|
||||||
|
|
||||||
|
|
||||||
|
@ -42,7 +42,6 @@ Typical usage:
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import time
|
|
||||||
from math import ceil
|
from math import ceil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||||
@ -56,8 +55,8 @@ from invokeai.backend.model_manager.config import (
|
|||||||
ModelType,
|
ModelType,
|
||||||
SubModelType,
|
SubModelType,
|
||||||
)
|
)
|
||||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore, UnknownMetadataException
|
|
||||||
from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel
|
from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel
|
||||||
|
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore, UnknownMetadataException
|
||||||
|
|
||||||
from ..shared.sqlite.sqlite_database import SqliteDatabase
|
from ..shared.sqlite.sqlite_database import SqliteDatabase
|
||||||
from .model_records_base import (
|
from .model_records_base import (
|
||||||
@ -72,7 +71,7 @@ from .model_records_base import (
|
|||||||
class ModelRecordServiceSQL(ModelRecordServiceBase):
|
class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||||
"""Implementation of the ModelConfigStore ABC using a SQL database."""
|
"""Implementation of the ModelConfigStore ABC using a SQL database."""
|
||||||
|
|
||||||
def __init__(self, db: SqliteDatabase, loader: Optional[AnyModelLoader]=None):
|
def __init__(self, db: SqliteDatabase, loader: Optional[AnyModelLoader] = None):
|
||||||
"""
|
"""
|
||||||
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
|
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
|
||||||
|
|
||||||
@ -289,7 +288,9 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
""",
|
""",
|
||||||
tuple(bindings),
|
tuple(bindings),
|
||||||
)
|
)
|
||||||
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()]
|
results = [
|
||||||
|
ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()
|
||||||
|
]
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
|
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
|
||||||
@ -303,7 +304,9 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
""",
|
""",
|
||||||
(str(path),),
|
(str(path),),
|
||||||
)
|
)
|
||||||
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()]
|
results = [
|
||||||
|
ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()
|
||||||
|
]
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
|
def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
|
||||||
@ -317,7 +320,9 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
""",
|
""",
|
||||||
(hash,),
|
(hash,),
|
||||||
)
|
)
|
||||||
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()]
|
results = [
|
||||||
|
ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()
|
||||||
|
]
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -1,11 +1,9 @@
|
|||||||
import sqlite3
|
import sqlite3
|
||||||
from logging import Logger
|
|
||||||
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
|
||||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
||||||
|
|
||||||
class Migration6Callback:
|
|
||||||
|
|
||||||
|
class Migration6Callback:
|
||||||
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||||
self._recreate_model_triggers(cursor)
|
self._recreate_model_triggers(cursor)
|
||||||
|
|
||||||
@ -28,6 +26,7 @@ class Migration6Callback:
|
|||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_migration_6() -> Migration:
|
def build_migration_6() -> Migration:
|
||||||
"""
|
"""
|
||||||
Build the migration from database version 5 to 6.
|
Build the migration from database version 5 to 6.
|
||||||
|
@ -120,6 +120,7 @@ class TqdmEventService(EventServiceBase):
|
|||||||
elif payload["event"] == "model_install_cancelled":
|
elif payload["event"] == "model_install_cancelled":
|
||||||
self._logger.warning(f"{source}: installation cancelled")
|
self._logger.warning(f"{source}: installation cancelled")
|
||||||
|
|
||||||
|
|
||||||
class InstallHelper(object):
|
class InstallHelper(object):
|
||||||
"""Capture information stored jointly in INITIAL_MODELS.yaml and the installed models db."""
|
"""Capture information stored jointly in INITIAL_MODELS.yaml and the installed models db."""
|
||||||
|
|
||||||
|
@ -139,7 +139,6 @@ def _convert_controlnet_ckpt_and_cache(
|
|||||||
cache it to disk, and return Path to converted
|
cache it to disk, and return Path to converted
|
||||||
file. If already on disk then just returns Path.
|
file. If already on disk then just returns Path.
|
||||||
"""
|
"""
|
||||||
print(f"DEBUG: controlnet config = {model_config}")
|
|
||||||
app_config = InvokeAIAppConfig.get_config()
|
app_config = InvokeAIAppConfig.get_config()
|
||||||
weights = app_config.root_path / model_path
|
weights = app_config.root_path / model_path
|
||||||
output_path = Path(output_path)
|
output_path = Path(output_path)
|
||||||
|
@ -13,9 +13,9 @@ from .config import (
|
|||||||
SchedulerPredictionType,
|
SchedulerPredictionType,
|
||||||
SubModelType,
|
SubModelType,
|
||||||
)
|
)
|
||||||
|
from .load import LoadedModel
|
||||||
from .probe import ModelProbe
|
from .probe import ModelProbe
|
||||||
from .search import ModelSearch
|
from .search import ModelSearch
|
||||||
from .load import LoadedModel
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AnyModel",
|
"AnyModel",
|
||||||
|
@ -20,14 +20,16 @@ Validation errors will raise an InvalidModelConfigException error.
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
import time
|
import time
|
||||||
import torch
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Literal, Optional, Type, Union
|
from typing import Literal, Optional, Type, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
|
import torch
|
||||||
from diffusers import ModelMixin
|
from diffusers import ModelMixin
|
||||||
|
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
|
||||||
from typing_extensions import Annotated, Any, Dict
|
from typing_extensions import Annotated, Any, Dict
|
||||||
|
|
||||||
from .onnx_runtime import IAIOnnxRuntimeModel
|
from .onnx_runtime import IAIOnnxRuntimeModel
|
||||||
|
from ..ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
|
||||||
|
|
||||||
class InvalidModelConfigException(Exception):
|
class InvalidModelConfigException(Exception):
|
||||||
"""Exception for when config parser doesn't recognized this combination of model type and format."""
|
"""Exception for when config parser doesn't recognized this combination of model type and format."""
|
||||||
@ -204,6 +206,8 @@ class _MainConfig(ModelConfigBase):
|
|||||||
|
|
||||||
vae: Optional[str] = Field(default=None)
|
vae: Optional[str] = Field(default=None)
|
||||||
variant: ModelVariantType = ModelVariantType.Normal
|
variant: ModelVariantType = ModelVariantType.Normal
|
||||||
|
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
||||||
|
upcast_attention: bool = False
|
||||||
ztsnr_training: bool = False
|
ztsnr_training: bool = False
|
||||||
|
|
||||||
|
|
||||||
@ -217,8 +221,6 @@ class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
|
|||||||
"""Model config for main diffusers models."""
|
"""Model config for main diffusers models."""
|
||||||
|
|
||||||
type: Literal[ModelType.Main] = ModelType.Main
|
type: Literal[ModelType.Main] = ModelType.Main
|
||||||
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
|
||||||
upcast_attention: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
class ONNXSD1Config(_MainConfig):
|
class ONNXSD1Config(_MainConfig):
|
||||||
@ -276,6 +278,7 @@ AnyModelConfig = Union[
|
|||||||
_ONNXConfig,
|
_ONNXConfig,
|
||||||
_VaeConfig,
|
_VaeConfig,
|
||||||
_ControlNetConfig,
|
_ControlNetConfig,
|
||||||
|
# ModelConfigBase,
|
||||||
LoRAConfig,
|
LoRAConfig,
|
||||||
TextualInversionConfig,
|
TextualInversionConfig,
|
||||||
IPAdapterConfig,
|
IPAdapterConfig,
|
||||||
@ -284,7 +287,7 @@ AnyModelConfig = Union[
|
|||||||
]
|
]
|
||||||
|
|
||||||
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
|
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
|
||||||
AnyModel = Union[ModelMixin, torch.nn.Module, IAIOnnxRuntimeModel]
|
AnyModel = Union[ModelMixin, torch.nn.Module, IAIOnnxRuntimeModel, IPAdapter, IPAdapterPlus]
|
||||||
|
|
||||||
# IMPLEMENTATION NOTE:
|
# IMPLEMENTATION NOTE:
|
||||||
# The preferred alternative to the above is a discriminated Union as shown
|
# The preferred alternative to the above is a discriminated Union as shown
|
||||||
@ -317,7 +320,7 @@ class ModelConfigFactory(object):
|
|||||||
model_data: Union[dict, AnyModelConfig],
|
model_data: Union[dict, AnyModelConfig],
|
||||||
key: Optional[str] = None,
|
key: Optional[str] = None,
|
||||||
dest_class: Optional[Type] = None,
|
dest_class: Optional[Type] = None,
|
||||||
timestamp: Optional[float] = None
|
timestamp: Optional[float] = None,
|
||||||
) -> AnyModelConfig:
|
) -> AnyModelConfig:
|
||||||
"""
|
"""
|
||||||
Return the appropriate config object from raw dict values.
|
Return the appropriate config object from raw dict values.
|
||||||
|
@ -43,7 +43,6 @@ from diffusers.schedulers import (
|
|||||||
UnCLIPScheduler,
|
UnCLIPScheduler,
|
||||||
)
|
)
|
||||||
from diffusers.utils import is_accelerate_available
|
from diffusers.utils import is_accelerate_available
|
||||||
from diffusers.utils.import_utils import BACKENDS_MAPPING
|
|
||||||
from picklescan.scanner import scan_file_path
|
from picklescan.scanner import scan_file_path
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoFeatureExtractor,
|
AutoFeatureExtractor,
|
||||||
@ -58,8 +57,8 @@ from transformers import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
|
||||||
from invokeai.backend.model_manager import BaseModelType, ModelVariantType
|
from invokeai.backend.model_manager import BaseModelType, ModelVariantType
|
||||||
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
@ -1643,7 +1642,6 @@ def download_controlnet_from_original_ckpt(
|
|||||||
cross_attention_dim: Optional[bool] = None,
|
cross_attention_dim: Optional[bool] = None,
|
||||||
scan_needed: bool = False,
|
scan_needed: bool = False,
|
||||||
) -> DiffusionPipeline:
|
) -> DiffusionPipeline:
|
||||||
|
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
if from_safetensors:
|
if from_safetensors:
|
||||||
|
@ -8,14 +8,15 @@ from typing import Optional
|
|||||||
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
|
|
||||||
|
from .convert_cache.convert_cache_default import ModelConvertCache
|
||||||
from .load_base import AnyModelLoader, LoadedModel
|
from .load_base import AnyModelLoader, LoadedModel
|
||||||
from .model_cache.model_cache_default import ModelCache
|
from .model_cache.model_cache_default import ModelCache
|
||||||
from .convert_cache.convert_cache_default import ModelConvertCache
|
|
||||||
|
|
||||||
# This registers the subclasses that implement loaders of specific model types
|
# This registers the subclasses that implement loaders of specific model types
|
||||||
loaders = [x.stem for x in Path(Path(__file__).parent,'model_loaders').glob('*.py') if x.stem != '__init__']
|
loaders = [x.stem for x in Path(Path(__file__).parent, "model_loaders").glob("*.py") if x.stem != "__init__"]
|
||||||
for module in loaders:
|
for module in loaders:
|
||||||
print(f'module={module}')
|
print(f"module={module}")
|
||||||
import_module(f"{__package__}.model_loaders.{module}")
|
import_module(f"{__package__}.model_loaders.{module}")
|
||||||
|
|
||||||
__all__ = ["AnyModelLoader", "LoadedModel"]
|
__all__ = ["AnyModelLoader", "LoadedModel"]
|
||||||
@ -24,12 +25,11 @@ __all__ = ["AnyModelLoader", "LoadedModel"]
|
|||||||
def get_standalone_loader(app_config: Optional[InvokeAIAppConfig]) -> AnyModelLoader:
|
def get_standalone_loader(app_config: Optional[InvokeAIAppConfig]) -> AnyModelLoader:
|
||||||
app_config = app_config or InvokeAIAppConfig.get_config()
|
app_config = app_config or InvokeAIAppConfig.get_config()
|
||||||
logger = InvokeAILogger.get_logger(config=app_config)
|
logger = InvokeAILogger.get_logger(config=app_config)
|
||||||
return AnyModelLoader(app_config=app_config,
|
return AnyModelLoader(
|
||||||
|
app_config=app_config,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
ram_cache=ModelCache(logger=logger,
|
ram_cache=ModelCache(
|
||||||
max_cache_size=app_config.ram_cache_size,
|
logger=logger, max_cache_size=app_config.ram_cache_size, max_vram_cache_size=app_config.vram_cache_size
|
||||||
max_vram_cache_size=app_config.vram_cache_size
|
|
||||||
),
|
),
|
||||||
convert_cache=ModelConvertCache(app_config.models_convert_cache_path)
|
convert_cache=ModelConvertCache(app_config.models_convert_cache_path),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from .convert_cache_base import ModelConvertCacheBase
|
from .convert_cache_base import ModelConvertCacheBase
|
||||||
from .convert_cache_default import ModelConvertCache
|
from .convert_cache_default import ModelConvertCache
|
||||||
|
|
||||||
__all__ = ['ModelConvertCacheBase', 'ModelConvertCache']
|
__all__ = ["ModelConvertCacheBase", "ModelConvertCache"]
|
||||||
|
@ -4,8 +4,8 @@ Disk-based converted model cache.
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
class ModelConvertCacheBase(ABC):
|
|
||||||
|
|
||||||
|
class ModelConvertCacheBase(ABC):
|
||||||
@property
|
@property
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def max_size(self) -> float:
|
def max_size(self) -> float:
|
||||||
@ -25,4 +25,3 @@ class ModelConvertCacheBase(ABC):
|
|||||||
def cache_path(self, key: str) -> Path:
|
def cache_path(self, key: str) -> Path:
|
||||||
"""Return the path for a model with the indicated key."""
|
"""Return the path for a model with the indicated key."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -2,15 +2,17 @@
|
|||||||
Placeholder for convert cache implementation.
|
Placeholder for convert cache implementation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
import shutil
|
import shutil
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from pathlib import Path
|
||||||
|
|
||||||
from invokeai.backend.util import GIG, directory_size
|
from invokeai.backend.util import GIG, directory_size
|
||||||
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
|
|
||||||
from .convert_cache_base import ModelConvertCacheBase
|
from .convert_cache_base import ModelConvertCacheBase
|
||||||
|
|
||||||
class ModelConvertCache(ModelConvertCacheBase):
|
|
||||||
|
|
||||||
def __init__(self, cache_path: Path, max_size: float=10.0):
|
class ModelConvertCache(ModelConvertCacheBase):
|
||||||
|
def __init__(self, cache_path: Path, max_size: float = 10.0):
|
||||||
"""Initialize the convert cache with the base directory and a limit on its maximum size (in GBs)."""
|
"""Initialize the convert cache with the base directory and a limit on its maximum size (in GBs)."""
|
||||||
if not cache_path.exists():
|
if not cache_path.exists():
|
||||||
cache_path.mkdir(parents=True)
|
cache_path.mkdir(parents=True)
|
||||||
|
@ -10,17 +10,19 @@ Use like this:
|
|||||||
# do something with loaded_model
|
# do something with loaded_model
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Dict, Optional, Type, Union
|
from typing import Any, Callable, Dict, Optional, Tuple, Type
|
||||||
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
|
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
|
||||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
from invokeai.backend.model_manager.config import VaeCheckpointConfig, VaeDiffusersConfig
|
||||||
from invokeai.backend.model_manager.load.model_cache.model_locker import ModelLockerBase
|
|
||||||
from invokeai.backend.model_manager.load.convert_cache.convert_cache_base import ModelConvertCacheBase
|
from invokeai.backend.model_manager.load.convert_cache.convert_cache_base import ModelConvertCacheBase
|
||||||
|
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class LoadedModel:
|
class LoadedModel:
|
||||||
@ -52,7 +54,7 @@ class ModelLoaderBase(ABC):
|
|||||||
self,
|
self,
|
||||||
app_config: InvokeAIAppConfig,
|
app_config: InvokeAIAppConfig,
|
||||||
logger: Logger,
|
logger: Logger,
|
||||||
ram_cache: ModelCacheBase,
|
ram_cache: ModelCacheBase[AnyModel],
|
||||||
convert_cache: ModelConvertCacheBase,
|
convert_cache: ModelConvertCacheBase,
|
||||||
):
|
):
|
||||||
"""Initialize the loader."""
|
"""Initialize the loader."""
|
||||||
@ -91,7 +93,7 @@ class AnyModelLoader:
|
|||||||
self,
|
self,
|
||||||
app_config: InvokeAIAppConfig,
|
app_config: InvokeAIAppConfig,
|
||||||
logger: Logger,
|
logger: Logger,
|
||||||
ram_cache: ModelCacheBase,
|
ram_cache: ModelCacheBase[AnyModel],
|
||||||
convert_cache: ModelConvertCacheBase,
|
convert_cache: ModelConvertCacheBase,
|
||||||
):
|
):
|
||||||
"""Initialize AnyModelLoader with its dependencies."""
|
"""Initialize AnyModelLoader with its dependencies."""
|
||||||
@ -101,11 +103,11 @@ class AnyModelLoader:
|
|||||||
self._convert_cache = convert_cache
|
self._convert_cache = convert_cache
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def ram_cache(self) -> ModelCacheBase:
|
def ram_cache(self) -> ModelCacheBase[AnyModel]:
|
||||||
"""Return the RAM cache associated used by the loaders."""
|
"""Return the RAM cache associated used by the loaders."""
|
||||||
return self._ram_cache
|
return self._ram_cache
|
||||||
|
|
||||||
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType]=None) -> LoadedModel:
|
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||||
"""
|
"""
|
||||||
Return a model given its configuration.
|
Return a model given its configuration.
|
||||||
|
|
||||||
@ -113,9 +115,7 @@ class AnyModelLoader:
|
|||||||
:param submodel_type: an ModelType enum indicating the portion of
|
:param submodel_type: an ModelType enum indicating the portion of
|
||||||
the model to retrieve (e.g. ModelType.Vae)
|
the model to retrieve (e.g. ModelType.Vae)
|
||||||
"""
|
"""
|
||||||
implementation = self.__class__.get_implementation(
|
implementation, model_config, submodel_type = self.__class__.get_implementation(model_config, submodel_type)
|
||||||
base=model_config.base, type=model_config.type, format=model_config.format
|
|
||||||
)
|
|
||||||
return implementation(
|
return implementation(
|
||||||
app_config=self._app_config,
|
app_config=self._app_config,
|
||||||
logger=self._logger,
|
logger=self._logger,
|
||||||
@ -128,16 +128,37 @@ class AnyModelLoader:
|
|||||||
return "-".join([base.value, type.value, format.value])
|
return "-".join([base.value, type.value, format.value])
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_implementation(cls, base: BaseModelType, type: ModelType, format: ModelFormat) -> Type[ModelLoaderBase]:
|
def get_implementation(
|
||||||
|
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
|
||||||
|
) -> Tuple[Type[ModelLoaderBase], AnyModelConfig, Optional[SubModelType]]:
|
||||||
"""Get subclass of ModelLoaderBase registered to handle base and type."""
|
"""Get subclass of ModelLoaderBase registered to handle base and type."""
|
||||||
key1 = cls._to_registry_key(base, type, format) # for a specific base type
|
# We have to handle VAE overrides here because this will change the model type and the corresponding implementation returned
|
||||||
key2 = cls._to_registry_key(BaseModelType.Any, type, format) # with wildcard Any
|
conf2, submodel_type = cls._handle_subtype_overrides(config, submodel_type)
|
||||||
|
|
||||||
|
key1 = cls._to_registry_key(conf2.base, conf2.type, conf2.format) # for a specific base type
|
||||||
|
key2 = cls._to_registry_key(BaseModelType.Any, conf2.type, conf2.format) # with wildcard Any
|
||||||
implementation = cls._registry.get(key1) or cls._registry.get(key2)
|
implementation = cls._registry.get(key1) or cls._registry.get(key2)
|
||||||
if not implementation:
|
if not implementation:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"No subclass of LoadedModel is registered for base={base}, type={type}, format={format}"
|
f"No subclass of LoadedModel is registered for base={config.base}, type={config.type}, format={config.format}"
|
||||||
)
|
)
|
||||||
return implementation
|
return implementation, conf2, submodel_type
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _handle_subtype_overrides(
|
||||||
|
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
|
||||||
|
) -> Tuple[AnyModelConfig, Optional[SubModelType]]:
|
||||||
|
if submodel_type == SubModelType.Vae and hasattr(config, "vae") and config.vae is not None:
|
||||||
|
model_path = Path(config.vae)
|
||||||
|
config_class = (
|
||||||
|
VaeCheckpointConfig if model_path.suffix in [".pt", ".safetensors", ".ckpt"] else VaeDiffusersConfig
|
||||||
|
)
|
||||||
|
hash = hashlib.md5(model_path.as_posix().encode("utf-8")).hexdigest()
|
||||||
|
new_conf = config_class(path=model_path.as_posix(), name=model_path.stem, base=config.base, key=hash)
|
||||||
|
submodel_type = None
|
||||||
|
else:
|
||||||
|
new_conf = config
|
||||||
|
return new_conf, submodel_type
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def register(
|
def register(
|
||||||
@ -152,4 +173,3 @@ class AnyModelLoader:
|
|||||||
return subclass
|
return subclass
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
@ -10,12 +10,12 @@ from diffusers import ModelMixin
|
|||||||
from diffusers.configuration_utils import ConfigMixin
|
from diffusers.configuration_utils import ConfigMixin
|
||||||
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.backend.model_manager import AnyModelConfig, InvalidModelConfigException, ModelRepoVariant, SubModelType
|
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, InvalidModelConfigException, ModelRepoVariant, SubModelType
|
||||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||||
from invokeai.backend.model_manager.load.load_base import AnyModel, LoadedModel, ModelLoaderBase
|
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
|
||||||
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
|
|
||||||
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
|
|
||||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
|
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
|
||||||
|
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs, calc_model_size_by_data
|
||||||
|
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
|
||||||
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
|
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
|
||||||
|
|
||||||
|
|
||||||
@ -38,7 +38,7 @@ class ModelLoader(ModelLoaderBase):
|
|||||||
self,
|
self,
|
||||||
app_config: InvokeAIAppConfig,
|
app_config: InvokeAIAppConfig,
|
||||||
logger: Logger,
|
logger: Logger,
|
||||||
ram_cache: ModelCacheBase,
|
ram_cache: ModelCacheBase[AnyModel],
|
||||||
convert_cache: ModelConvertCacheBase,
|
convert_cache: ModelConvertCacheBase,
|
||||||
):
|
):
|
||||||
"""Initialize the loader."""
|
"""Initialize the loader."""
|
||||||
@ -47,7 +47,6 @@ class ModelLoader(ModelLoaderBase):
|
|||||||
self._ram_cache = ram_cache
|
self._ram_cache = ram_cache
|
||||||
self._convert_cache = convert_cache
|
self._convert_cache = convert_cache
|
||||||
self._torch_dtype = torch_dtype(choose_torch_device())
|
self._torch_dtype = torch_dtype(choose_torch_device())
|
||||||
self._size: Optional[int] = None # model size
|
|
||||||
|
|
||||||
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||||
"""
|
"""
|
||||||
@ -63,9 +62,7 @@ class ModelLoader(ModelLoaderBase):
|
|||||||
if model_config.type == "main" and not submodel_type:
|
if model_config.type == "main" and not submodel_type:
|
||||||
raise InvalidModelConfigException("submodel_type is required when loading a main model")
|
raise InvalidModelConfigException("submodel_type is required when loading a main model")
|
||||||
|
|
||||||
model_path, is_submodel_override = self._get_model_path(model_config, submodel_type)
|
model_path, model_config, submodel_type = self._get_model_path(model_config, submodel_type)
|
||||||
if is_submodel_override:
|
|
||||||
submodel_type = None
|
|
||||||
|
|
||||||
if not model_path.exists():
|
if not model_path.exists():
|
||||||
raise InvalidModelConfigException(f"Files for model 'model_config.name' not found at {model_path}")
|
raise InvalidModelConfigException(f"Files for model 'model_config.name' not found at {model_path}")
|
||||||
@ -74,13 +71,12 @@ class ModelLoader(ModelLoaderBase):
|
|||||||
locker = self._load_if_needed(model_config, model_path, submodel_type)
|
locker = self._load_if_needed(model_config, model_path, submodel_type)
|
||||||
return LoadedModel(config=model_config, locker=locker)
|
return LoadedModel(config=model_config, locker=locker)
|
||||||
|
|
||||||
# IMPORTANT: This needs to be overridden in the StableDiffusion subclass so as to handle vae overrides
|
|
||||||
# and submodels!!!!
|
|
||||||
def _get_model_path(
|
def _get_model_path(
|
||||||
self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None
|
self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None
|
||||||
) -> Tuple[Path, bool]:
|
) -> Tuple[Path, AnyModelConfig, Optional[SubModelType]]:
|
||||||
model_base = self._app_config.models_path
|
model_base = self._app_config.models_path
|
||||||
return ((model_base / config.path).resolve(), False)
|
result = (model_base / config.path).resolve(), config, submodel_type
|
||||||
|
return result
|
||||||
|
|
||||||
def _convert_if_needed(
|
def _convert_if_needed(
|
||||||
self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
|
self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
|
||||||
@ -90,7 +86,7 @@ class ModelLoader(ModelLoaderBase):
|
|||||||
if not self._needs_conversion(config, model_path, cache_path):
|
if not self._needs_conversion(config, model_path, cache_path):
|
||||||
return cache_path if cache_path.exists() else model_path
|
return cache_path if cache_path.exists() else model_path
|
||||||
|
|
||||||
self._convert_cache.make_room(self._size or self.get_size_fs(config, model_path, submodel_type))
|
self._convert_cache.make_room(self.get_size_fs(config, model_path, submodel_type))
|
||||||
return self._convert_model(config, model_path, cache_path)
|
return self._convert_model(config, model_path, cache_path)
|
||||||
|
|
||||||
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, cache_path: Path) -> bool:
|
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, cache_path: Path) -> bool:
|
||||||
@ -114,6 +110,7 @@ class ModelLoader(ModelLoaderBase):
|
|||||||
config.key,
|
config.key,
|
||||||
submodel_type=submodel_type,
|
submodel_type=submodel_type,
|
||||||
model=loaded_model,
|
model=loaded_model,
|
||||||
|
size=calc_model_size_by_data(loaded_model),
|
||||||
)
|
)
|
||||||
|
|
||||||
return self._ram_cache.get(config.key, submodel_type)
|
return self._ram_cache.get(config.key, submodel_type)
|
||||||
@ -128,17 +125,6 @@ class ModelLoader(ModelLoaderBase):
|
|||||||
variant=config.repo_variant if hasattr(config, "repo_variant") else None,
|
variant=config.repo_variant if hasattr(config, "repo_variant") else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def _load_model(
|
|
||||||
self,
|
|
||||||
model_path: Path,
|
|
||||||
model_variant: Optional[ModelRepoVariant] = None,
|
|
||||||
submodel_type: Optional[SubModelType] = None,
|
|
||||||
) -> AnyModel:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> Dict[str, Any]:
|
def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> Dict[str, Any]:
|
||||||
return ConfigLoader.load_config(model_path, config_name=config_name)
|
return ConfigLoader.load_config(model_path, config_name=config_name)
|
||||||
|
|
||||||
@ -161,3 +147,17 @@ class ModelLoader(ModelLoaderBase):
|
|||||||
config = self._load_diffusers_config(model_path, config_name="config.json")
|
config = self._load_diffusers_config(model_path, config_name="config.json")
|
||||||
class_name = config["_class_name"]
|
class_name = config["_class_name"]
|
||||||
return self._hf_definition_to_type(module="diffusers", class_name=class_name)
|
return self._hf_definition_to_type(module="diffusers", class_name=class_name)
|
||||||
|
|
||||||
|
# This needs to be implemented in subclasses that handle checkpoints
|
||||||
|
def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
# This needs to be implemented in the subclass
|
||||||
|
def _load_model(
|
||||||
|
self,
|
||||||
|
model_path: Path,
|
||||||
|
model_variant: Optional[ModelRepoVariant] = None,
|
||||||
|
submodel_type: Optional[SubModelType] = None,
|
||||||
|
) -> AnyModel:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@ -97,4 +97,4 @@ def get_pretty_snapshot_diff(snapshot_1: Optional[MemorySnapshot], snapshot_2: O
|
|||||||
if snapshot_1.vram is not None and snapshot_2.vram is not None:
|
if snapshot_1.vram is not None and snapshot_2.vram is not None:
|
||||||
msg += get_msg_line("VRAM", snapshot_1.vram, snapshot_2.vram)
|
msg += get_msg_line("VRAM", snapshot_1.vram, snapshot_2.vram)
|
||||||
|
|
||||||
return msg
|
return "\n"+msg if len(msg)>0 else msg
|
||||||
|
@ -2,4 +2,4 @@
|
|||||||
|
|
||||||
from .model_cache_base import ModelCacheBase
|
from .model_cache_base import ModelCacheBase
|
||||||
from .model_cache_default import ModelCache
|
from .model_cache_default import ModelCache
|
||||||
_all__ = ['ModelCacheBase', 'ModelCache']
|
# Public API of this package; the name must be spelled `__all__` to be honoured by `import *`.
__all__ = ["ModelCacheBase", "ModelCache"]
|
||||||
|
@ -8,14 +8,15 @@ model will be cleared and (re)loaded from disk when next needed.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
from typing import Dict, Optional, TypeVar, Generic
|
from typing import Generic, Optional, TypeVar
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from invokeai.backend.model_manager import AnyModel, SubModelType
|
from invokeai.backend.model_manager import AnyModel, SubModelType
|
||||||
|
|
||||||
|
|
||||||
class ModelLockerBase(ABC):
|
class ModelLockerBase(ABC):
|
||||||
"""Base class for the model locker used by the loader."""
|
"""Base class for the model locker used by the loader."""
|
||||||
|
|
||||||
@ -35,8 +36,10 @@ class ModelLockerBase(ABC):
|
|||||||
"""Return the model."""
|
"""Return the model."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CacheRecord(Generic[T]):
|
class CacheRecord(Generic[T]):
|
||||||
"""Elements of the cache."""
|
"""Elements of the cache."""
|
||||||
@ -115,6 +118,7 @@ class ModelCacheBase(ABC, Generic[T]):
|
|||||||
self,
|
self,
|
||||||
key: str,
|
key: str,
|
||||||
model: T,
|
model: T,
|
||||||
|
size: int,
|
||||||
submodel_type: Optional[SubModelType] = None,
|
submodel_type: Optional[SubModelType] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Store model under key and optional submodel_type."""
|
"""Store model under key and optional submodel_type."""
|
||||||
|
@ -19,22 +19,24 @@ context. Use like this:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import gc
|
import gc
|
||||||
|
import logging
|
||||||
import math
|
import math
|
||||||
|
import sys
|
||||||
import time
|
import time
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from invokeai.backend.model_manager import SubModelType
|
from invokeai.backend.model_manager import SubModelType
|
||||||
from invokeai.backend.model_manager.load.load_base import AnyModel
|
from invokeai.backend.model_manager.load.load_base import AnyModel
|
||||||
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
|
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
|
||||||
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
|
|
||||||
from invokeai.backend.util.devices import choose_torch_device
|
from invokeai.backend.util.devices import choose_torch_device
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
|
|
||||||
from .model_cache_base import CacheRecord, ModelCacheBase
|
from .model_cache_base import CacheRecord, ModelCacheBase
|
||||||
from .model_locker import ModelLockerBase, ModelLocker
|
from .model_locker import ModelLocker, ModelLockerBase
|
||||||
|
|
||||||
if choose_torch_device() == torch.device("mps"):
|
if choose_torch_device() == torch.device("mps"):
|
||||||
from torch import mps
|
from torch import mps
|
||||||
@ -91,7 +93,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
|||||||
self._execution_device: torch.device = execution_device
|
self._execution_device: torch.device = execution_device
|
||||||
self._storage_device: torch.device = storage_device
|
self._storage_device: torch.device = storage_device
|
||||||
self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
|
self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
|
||||||
self._log_memory_usage = log_memory_usage
|
self._log_memory_usage = log_memory_usage or self._logger.level == logging.DEBUG
|
||||||
|
|
||||||
self._cached_models: Dict[str, CacheRecord[AnyModel]] = {}
|
self._cached_models: Dict[str, CacheRecord[AnyModel]] = {}
|
||||||
self._cache_stack: List[str] = []
|
self._cache_stack: List[str] = []
|
||||||
@ -141,14 +143,14 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
|||||||
self,
|
self,
|
||||||
key: str,
|
key: str,
|
||||||
model: AnyModel,
|
model: AnyModel,
|
||||||
|
size: int,
|
||||||
submodel_type: Optional[SubModelType] = None,
|
submodel_type: Optional[SubModelType] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Store model under key and optional submodel_type."""
|
"""Store model under key and optional submodel_type."""
|
||||||
key = self._make_cache_key(key, submodel_type)
|
key = self._make_cache_key(key, submodel_type)
|
||||||
assert key not in self._cached_models
|
assert key not in self._cached_models
|
||||||
|
|
||||||
loaded_model_size = calc_model_size_by_data(model)
|
cache_record = CacheRecord(key, model, size)
|
||||||
cache_record = CacheRecord(key, model, loaded_model_size)
|
|
||||||
self._cached_models[key] = cache_record
|
self._cached_models[key] = cache_record
|
||||||
self._cache_stack.append(key)
|
self._cache_stack.append(key)
|
||||||
|
|
||||||
@ -195,28 +197,32 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
|||||||
for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
|
for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
|
||||||
if vram_in_use <= reserved:
|
if vram_in_use <= reserved:
|
||||||
break
|
break
|
||||||
|
if not cache_entry.loaded:
|
||||||
|
continue
|
||||||
if not cache_entry.locked:
|
if not cache_entry.locked:
|
||||||
self.move_model_to_device(cache_entry, self.storage_device)
|
self.move_model_to_device(cache_entry, self.storage_device)
|
||||||
cache_entry.loaded = False
|
cache_entry.loaded = False
|
||||||
vram_in_use = torch.cuda.memory_allocated() + size_required
|
vram_in_use = torch.cuda.memory_allocated() + size_required
|
||||||
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM now available for models; max allowed={(reserved/GIG):.2f}GB")
|
self.logger.debug(
|
||||||
|
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
|
||||||
|
)
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
if choose_torch_device() == torch.device("mps"):
|
if choose_torch_device() == torch.device("mps"):
|
||||||
mps.empty_cache()
|
mps.empty_cache()
|
||||||
|
|
||||||
# TO DO: Only reason to pass the CacheRecord rather than the model is to get the key and size
|
def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
|
||||||
# for printing debugging messages. Revisit whether this is necessary
|
|
||||||
def move_model_to_device(self, cache_entry: CacheRecord, target_device: torch.device) -> None:
|
|
||||||
"""Move model into the indicated device."""
|
"""Move model into the indicated device."""
|
||||||
# These attributes are not in the base class but in derived classes
|
# These attributes are not in the base ModelMixin class but in derived classes.
|
||||||
assert hasattr(cache_entry.model, "device")
|
# Some models don't have these attributes, in which case they run in RAM/CPU.
|
||||||
assert hasattr(cache_entry.model, "to")
|
self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
|
||||||
|
if not (hasattr(cache_entry.model, "device") and hasattr(cache_entry.model, "to")):
|
||||||
|
return
|
||||||
|
|
||||||
source_device = cache_entry.model.device
|
source_device = cache_entry.model.device
|
||||||
|
|
||||||
# Note: We compare device types only so that 'cuda' == 'cuda:0'. This would need to be revised to support
|
# Note: We compare device types only so that 'cuda' == 'cuda:0'.
|
||||||
# multi-GPU.
|
# This would need to be revised to support multi-GPU.
|
||||||
if torch.device(source_device).type == torch.device(target_device).type:
|
if torch.device(source_device).type == torch.device(target_device).type:
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -227,8 +233,8 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
|||||||
end_model_to_time = time.time()
|
end_model_to_time = time.time()
|
||||||
self.logger.debug(
|
self.logger.debug(
|
||||||
f"Moved model '{cache_entry.key}' from {source_device} to"
|
f"Moved model '{cache_entry.key}' from {source_device} to"
|
||||||
f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s.\n"
|
f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
|
||||||
f"Estimated model size: {(cache_entry.size/GIG):.3f} GB.\n"
|
f"Estimated model size: {(cache_entry.size/GIG):.3f} GB."
|
||||||
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
|
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -291,7 +297,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
|||||||
f" {(bytes_needed/GIG):.2f} GB"
|
f" {(bytes_needed/GIG):.2f} GB"
|
||||||
)
|
)
|
||||||
|
|
||||||
self.logger.debug(f"Before unloading: cached_models={len(self._cached_models)}")
|
self.logger.debug(f"Before making_room: cached_models={len(self._cached_models)}")
|
||||||
|
|
||||||
pos = 0
|
pos = 0
|
||||||
models_cleared = 0
|
models_cleared = 0
|
||||||
@ -336,7 +342,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
|||||||
# 1 from onnx runtime object
|
# 1 from onnx runtime object
|
||||||
if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2):
|
if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2):
|
||||||
self.logger.debug(
|
self.logger.debug(
|
||||||
f"Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
|
f"Removing {model_key} from RAM cache to free at least {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
|
||||||
)
|
)
|
||||||
current_size -= cache_entry.size
|
current_size -= cache_entry.size
|
||||||
models_cleared += 1
|
models_cleared += 1
|
||||||
@ -365,4 +371,4 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
|||||||
if choose_torch_device() == torch.device("mps"):
|
if choose_torch_device() == torch.device("mps"):
|
||||||
mps.empty_cache()
|
mps.empty_cache()
|
||||||
|
|
||||||
self.logger.debug(f"After unloading: cached_models={len(self._cached_models)}")
|
self.logger.debug(f"After making room: cached_models={len(self._cached_models)}")
|
||||||
|
@ -2,9 +2,10 @@
|
|||||||
Base class and implementation of a class that moves models in and out of VRAM.
|
Base class and implementation of a class that moves models in and out of VRAM.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from invokeai.backend.model_manager import AnyModel
|
from invokeai.backend.model_manager import AnyModel
|
||||||
from .model_cache_base import ModelLockerBase, ModelCacheBase, CacheRecord
|
|
||||||
|
from .model_cache_base import CacheRecord, ModelCacheBase, ModelLockerBase
|
||||||
|
|
||||||
|
|
||||||
class ModelLocker(ModelLockerBase):
|
class ModelLocker(ModelLockerBase):
|
||||||
"""Internal class that mediates movement in and out of GPU."""
|
"""Internal class that mediates movement in and out of GPU."""
|
||||||
@ -56,4 +57,3 @@ class ModelLocker(ModelLockerBase):
|
|||||||
if not self._cache.lazy_offloading:
|
if not self._cache.lazy_offloading:
|
||||||
self._cache.offload_unlocked_models(self._cache_entry.size)
|
self._cache.offload_unlocked_models(self._cache_entry.size)
|
||||||
self._cache.print_cuda_stats()
|
self._cache.print_cuda_stats()
|
||||||
|
|
||||||
|
@ -0,0 +1,60 @@
|
|||||||
|
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
|
||||||
|
"""Class for ControlNet model loading in InvokeAI."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import safetensors
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from invokeai.backend.model_manager import (
|
||||||
|
AnyModelConfig,
|
||||||
|
BaseModelType,
|
||||||
|
ModelFormat,
|
||||||
|
ModelType,
|
||||||
|
)
|
||||||
|
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_controlnet_to_diffusers
|
||||||
|
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
|
||||||
|
from .generic_diffusers import GenericDiffusersLoader
|
||||||
|
|
||||||
|
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Diffusers)
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Checkpoint)
class ControlnetLoader(GenericDiffusersLoader):
    """Load ControlNet models registered in diffusers or checkpoint format.

    This class defines no `_load_model` of its own, so loading comes from
    GenericDiffusersLoader; it only adds the checkpoint conversion hooks.
    """

    def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
        """Return True when the checkpoint must be (re)converted to diffusers format.

        Args:
            config: Model configuration; only `format` and `last_modified` are read.
            model_path: Path to the original checkpoint weights.
            dest_path: Directory that holds (or will hold) the converted model.

        Returns:
            False for non-checkpoint models, and for checkpoints whose converted
            `config.json` is at least as new as both `config.last_modified` and
            the checkpoint file. True otherwise.
        """
        if config.format != ModelFormat.Checkpoint:
            return False

        converted_config = dest_path / "config.json"
        # Covers both "never converted" and "conversion directory exists but is
        # incomplete"; the latter used to raise FileNotFoundError from stat().
        if not converted_config.exists():
            return True

        converted_mtime = converted_config.stat().st_mtime
        is_current = (
            converted_mtime >= (config.last_modified or 0.0)
            and converted_mtime >= model_path.stat().st_mtime
        )
        return not is_current

    def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
        """Convert a ControlNet checkpoint into diffusers format.

        Args:
            config: Model configuration; must carry a `config` attribute naming
                the original config file, relative to the application root.
            weights_path: Path to the checkpoint (`.safetensors` or other).
            output_path: Destination directory for the converted model.

        Returns:
            `output_path`, after conversion.

        Raises:
            Exception: If `config.base` is not StableDiffusion1 or StableDiffusion2.
        """
        if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
            raise Exception(f"ControlNet conversion not supported for model type: {config.base}")

        assert hasattr(config, "config")
        config_file = config.config

        # The converter is handed the weights path, not a state dict, so the
        # checkpoint is deliberately not loaded here. The previous eager load was
        # never read and ran torch.load() before the converter was called.
        convert_controlnet_to_diffusers(
            weights_path,
            output_path,
            original_config_file=self._app_config.root_path / config_file,
            image_size=512,
            scan_needed=True,
            from_safetensors=weights_path.suffix == ".safetensors",
        )
        return output_path
|
@ -0,0 +1,34 @@
|
|||||||
|
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
|
||||||
|
"""Class for simple diffusers model loading in InvokeAI."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from invokeai.backend.model_manager import (
|
||||||
|
AnyModel,
|
||||||
|
BaseModelType,
|
||||||
|
ModelFormat,
|
||||||
|
ModelRepoVariant,
|
||||||
|
ModelType,
|
||||||
|
SubModelType,
|
||||||
|
)
|
||||||
|
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
|
||||||
|
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||||
|
|
||||||
|
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers)
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.T2IAdapter, format=ModelFormat.Diffusers)
class GenericDiffusersLoader(ModelLoader):
    """Loader for diffusers-format models that have no submodels."""

    def _load_model(
        self,
        model_path: Path,
        model_variant: Optional[ModelRepoVariant] = None,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        """Instantiate the model found at `model_path`.

        Args:
            model_path: Directory containing the diffusers model.
            model_variant: Optional repo variant; its `.value` is forwarded as
                the `variant` argument of `from_pretrained`.
            submodel_type: Must be None; these models have no submodels.

        Returns:
            The object returned by the resolved class's `from_pretrained`.

        Raises:
            Exception: If a submodel type is requested.
        """
        # Resolved first because the error message below reports the class.
        hf_class = self._get_hf_load_class(model_path)
        if submodel_type is not None:
            raise Exception(f"There are no submodels in models of type {hf_class}")

        variant_name = None
        if model_variant:
            variant_name = model_variant.value

        loaded: AnyModel = hf_class.from_pretrained(  # type: ignore
            model_path,
            torch_dtype=self._torch_dtype,
            variant=variant_name,
        )
        return loaded
|
@ -0,0 +1,39 @@
|
|||||||
|
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
|
||||||
|
"""Class for IP Adapter model loading in InvokeAI."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from invokeai.backend.ip_adapter.ip_adapter import build_ip_adapter
|
||||||
|
from invokeai.backend.model_manager import (
|
||||||
|
AnyModel,
|
||||||
|
BaseModelType,
|
||||||
|
ModelFormat,
|
||||||
|
ModelRepoVariant,
|
||||||
|
ModelType,
|
||||||
|
SubModelType,
|
||||||
|
)
|
||||||
|
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
|
||||||
|
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||||
|
|
||||||
|
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.InvokeAI)
class IPAdapterInvokeAILoader(ModelLoader):
    """Loader for IP-Adapter models registered under ModelFormat.InvokeAI."""

    def _load_model(
        self,
        model_path: Path,
        model_variant: Optional[ModelRepoVariant] = None,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        """Build an IP-Adapter from the `ip_adapter.bin` file inside `model_path`.

        Args:
            model_path: Directory holding the `ip_adapter.bin` checkpoint.
            model_variant: Accepted for interface compatibility; not used here.
            submodel_type: Must be None; IP-Adapter models have no submodels.

        Returns:
            The adapter produced by `build_ip_adapter`, built on the CPU device
            with this loader's torch dtype.

        Raises:
            ValueError: If a submodel type is requested.
        """
        if submodel_type is not None:
            raise ValueError("There are no submodels in an IP-Adapter model.")

        checkpoint_file = model_path / "ip_adapter.bin"
        cpu = torch.device("cpu")
        return build_ip_adapter(
            ip_adapter_ckpt_path=checkpoint_file,
            device=cpu,
            dtype=self._torch_dtype,
        )
|
||||||
|
|
76
invokeai/backend/model_manager/load/model_loaders/lora.py
Normal file
76
invokeai/backend/model_manager/load/model_loaders/lora.py
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
|
||||||
|
"""Class for LoRA model loading in InvokeAI."""
|
||||||
|
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
from logging import Logger
|
||||||
|
|
||||||
|
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
||||||
|
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||||
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
from invokeai.backend.model_manager import (
|
||||||
|
AnyModel,
|
||||||
|
AnyModelConfig,
|
||||||
|
BaseModelType,
|
||||||
|
ModelFormat,
|
||||||
|
ModelRepoVariant,
|
||||||
|
ModelType,
|
||||||
|
SubModelType,
|
||||||
|
)
|
||||||
|
from invokeai.backend.model_manager.lora import LoRAModelRaw
|
||||||
|
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
|
||||||
|
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||||
|
|
||||||
|
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Diffusers)
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Lycoris)
class LoraLoader(ModelLoader):
    """Loader for LoRA models in diffusers or LyCORIS format."""

    def __init__(
        self,
        app_config: InvokeAIAppConfig,
        logger: Logger,
        ram_cache: ModelCacheBase[AnyModel],
        convert_cache: ModelConvertCacheBase,
    ):
        """Initialize the loader."""
        super().__init__(app_config, logger, ram_cache, convert_cache)
        # Filled in by _get_model_path() so that _load_model(), which is not
        # given the config, can still pass the base model type along.
        self._model_base: Optional[BaseModelType] = None

    def _load_model(
        self,
        model_path: Path,
        model_variant: Optional[ModelRepoVariant] = None,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        """Load the LoRA weights at `model_path`.

        Args:
            model_path: Path handed to `LoRAModelRaw.from_checkpoint`.
            model_variant: Accepted for interface compatibility; not used here.
            submodel_type: Must be None; LoRA models have no submodels.

        Raises:
            ValueError: If a submodel type is requested.
        """
        if submodel_type is not None:
            raise ValueError("There are no submodels in a LoRA model.")
        return LoRAModelRaw.from_checkpoint(
            file_path=model_path,
            dtype=self._torch_dtype,
            base_model=self._model_base,
        )

    # override
    def _get_model_path(
        self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None
    ) -> Tuple[Path, AnyModelConfig, Optional[SubModelType]]:
        """Resolve the weights location and remember the model base as a side effect.

        For diffusers-format LoRAs the first existing file among
        `pytorch_lora_weights.safetensors` and `pytorch_lora_weights.bin` inside
        the model directory is used; if neither exists, or for any other
        format, the configured path itself is used.

        Returns:
            A tuple of the resolved path, `config`, and `submodel_type`, the
            latter two passed through unchanged.
        """
        # Stash the base for the later call to _load_model().
        self._model_base = config.base

        model_dir = self._app_config.models_path / config.path
        weights_path = model_dir
        if config.format == ModelFormat.Diffusers:
            candidates = (model_dir / f"pytorch_lora_weights.{ext}" for ext in ("safetensors", "bin"))
            weights_path = next((candidate for candidate in candidates if candidate.exists()), model_dir)

        return weights_path.resolve(), config, submodel_type
|
||||||
|
|
||||||
|
|
@ -0,0 +1,93 @@
|
|||||||
|
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
|
||||||
|
"""Class for StableDiffusion model loading in InvokeAI."""
|
||||||
|
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from diffusers import StableDiffusionInpaintPipeline, StableDiffusionPipeline
|
||||||
|
|
||||||
|
from invokeai.backend.model_manager import (
|
||||||
|
AnyModel,
|
||||||
|
AnyModelConfig,
|
||||||
|
BaseModelType,
|
||||||
|
ModelFormat,
|
||||||
|
ModelRepoVariant,
|
||||||
|
ModelType,
|
||||||
|
ModelVariantType,
|
||||||
|
SubModelType,
|
||||||
|
)
|
||||||
|
from invokeai.backend.model_manager.config import MainCheckpointConfig
|
||||||
|
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
|
||||||
|
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
|
||||||
|
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||||
|
|
||||||
|
|
||||||
|
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Diffusers)
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Checkpoint)
class StableDiffusionDiffusersModel(ModelLoader):
    """Load submodels of main (pipeline) models, converting checkpoints when needed."""

    # Maps the model base to the `model_type` string passed to convert_ckpt_to_diffusers().
    model_base_to_model_type = {
        BaseModelType.StableDiffusion1: "FrozenCLIPEmbedder",
        BaseModelType.StableDiffusion2: "FrozenOpenCLIPEmbedder",
        BaseModelType.StableDiffusionXL: "SDXL",
        BaseModelType.StableDiffusionXLRefiner: "SDXL-Refiner",
    }

    def _load_model(
        self,
        model_path: Path,
        model_variant: Optional[ModelRepoVariant] = None,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        """Load one submodel from the diffusers pipeline directory at `model_path`.

        Args:
            model_path: Root directory of the diffusers pipeline.
            model_variant: Optional repo variant; its `.value` is forwarded as
                the `variant` argument of `from_pretrained`.
            submodel_type: Which submodel to load; its `.value` names the
                subdirectory under `model_path`. Required.

        Returns:
            The object returned by the resolved class's `from_pretrained`.

        Raises:
            Exception: If `submodel_type` is None.
        """
        if submodel_type is None:
            raise Exception("A submodel type must be provided when loading main pipelines.")

        load_class = self._get_hf_load_class(model_path, submodel_type)
        variant = model_variant.value if model_variant else None
        model_path = model_path / submodel_type.value
        result: AnyModel = load_class.from_pretrained(
            model_path,
            torch_dtype=self._torch_dtype,
            variant=variant,
        )  # type: ignore
        return result

    def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
        """Return True when the checkpoint must be (re)converted to diffusers format.

        Args:
            config: Model configuration; only `format` and `last_modified` are read.
            model_path: Path to the original checkpoint weights.
            dest_path: Directory that holds (or will hold) the converted pipeline.

        Returns:
            False for non-checkpoint models, and for checkpoints whose converted
            `model_index.json` is at least as new as both `config.last_modified`
            and the checkpoint file. True otherwise.
        """
        if config.format != ModelFormat.Checkpoint:
            return False

        model_index = dest_path / "model_index.json"
        # Covers both "never converted" and "conversion directory exists but is
        # incomplete"; the latter used to raise FileNotFoundError from stat().
        if not model_index.exists():
            return True

        converted_mtime = model_index.stat().st_mtime
        is_current = (
            converted_mtime >= (config.last_modified or 0.0)
            and converted_mtime >= model_path.stat().st_mtime
        )
        return not is_current

    def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
        """Convert a main-model checkpoint into a diffusers pipeline directory.

        Args:
            config: Must be a MainCheckpointConfig; `variant`, `base` and
                `config` (the original config file, relative to the
                application root) are read from it.
            weights_path: Path to the checkpoint file.
            output_path: Destination directory for the converted pipeline.

        Returns:
            `output_path`, after conversion.

        Raises:
            KeyError: If `config.base` has no entry in `model_base_to_model_type`.
        """
        assert isinstance(config, MainCheckpointConfig)
        variant = config.variant
        base = config.base

        # Inpainting checkpoints are converted with the inpaint pipeline class.
        pipeline_class = (
            StableDiffusionInpaintPipeline if variant == ModelVariantType.Inpaint else StableDiffusionPipeline
        )

        config_file = config.config

        self._logger.info(f"Converting {weights_path} to diffusers format")
        convert_ckpt_to_diffusers(
            weights_path,
            output_path,
            model_type=self.model_base_to_model_type[base],
            model_version=base,
            model_variant=variant,
            original_config_file=self._app_config.root_path / config_file,
            extract_ema=True,
            scan_needed=True,
            pipeline_class=pipeline_class,
            from_safetensors=weights_path.suffix == ".safetensors",
            precision=self._torch_dtype,
            load_safety_checker=False,
        )
        return output_path
|
@ -2,68 +2,54 @@
|
|||||||
"""Class for VAE model loading in InvokeAI."""
|
"""Class for VAE model loading in InvokeAI."""
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import safetensors
|
import safetensors
|
||||||
from omegaconf import OmegaConf, DictConfig
|
import torch
|
||||||
from invokeai.backend.util.devices import torch_dtype
|
from omegaconf import DictConfig, OmegaConf
|
||||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelFormat, ModelRepoVariant, ModelType, SubModelType
|
|
||||||
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
|
from invokeai.backend.model_manager import (
|
||||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
AnyModelConfig,
|
||||||
|
BaseModelType,
|
||||||
|
ModelFormat,
|
||||||
|
ModelType,
|
||||||
|
)
|
||||||
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
|
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
|
||||||
|
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
|
||||||
|
from .generic_diffusers import GenericDiffusersLoader
|
||||||
|
|
||||||
|
|
||||||
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers)
|
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers)
|
||||||
@AnyModelLoader.register(base=BaseModelType.StableDiffusion1, type=ModelType.Vae, format=ModelFormat.Checkpoint)
|
@AnyModelLoader.register(base=BaseModelType.StableDiffusion1, type=ModelType.Vae, format=ModelFormat.Checkpoint)
|
||||||
@AnyModelLoader.register(base=BaseModelType.StableDiffusion2, type=ModelType.Vae, format=ModelFormat.Checkpoint)
|
@AnyModelLoader.register(base=BaseModelType.StableDiffusion2, type=ModelType.Vae, format=ModelFormat.Checkpoint)
|
||||||
class VaeDiffusersModel(ModelLoader):
|
class VaeLoader(GenericDiffusersLoader):
|
||||||
"""Class to load VAE models."""
|
"""Class to load VAE models."""
|
||||||
|
|
||||||
def _load_model(
|
|
||||||
self,
|
|
||||||
model_path: Path,
|
|
||||||
model_variant: Optional[ModelRepoVariant] = None,
|
|
||||||
submodel_type: Optional[SubModelType] = None,
|
|
||||||
) -> AnyModel:
|
|
||||||
if submodel_type is not None:
|
|
||||||
raise Exception("There are no submodels in VAEs")
|
|
||||||
vae_class = self._get_hf_load_class(model_path)
|
|
||||||
variant = model_variant.value if model_variant else None
|
|
||||||
result: AnyModel = vae_class.from_pretrained(
|
|
||||||
model_path, torch_dtype=self._torch_dtype, variant=variant
|
|
||||||
) # type: ignore
|
|
||||||
return result
|
|
||||||
|
|
||||||
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
|
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
|
||||||
print(f'DEBUG: last_modified={config.last_modified}')
|
|
||||||
print(f'DEBUG: cache_path={(dest_path / "config.json").stat().st_mtime}')
|
|
||||||
print(f'DEBUG: model_path={model_path.stat().st_mtime}')
|
|
||||||
if config.format != ModelFormat.Checkpoint:
|
if config.format != ModelFormat.Checkpoint:
|
||||||
return False
|
return False
|
||||||
elif dest_path.exists() \
|
elif (
|
||||||
and (dest_path / "config.json").stat().st_mtime >= config.last_modified \
|
dest_path.exists()
|
||||||
and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime:
|
and (dest_path / "config.json").stat().st_mtime >= (config.last_modified or 0.0)
|
||||||
|
and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime
|
||||||
|
):
|
||||||
return False
|
return False
|
||||||
else:
|
else:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _convert_model(self,
|
def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
|
||||||
config: AnyModelConfig,
|
# TO DO: check whether sdxl VAE models convert.
|
||||||
weights_path: Path,
|
|
||||||
output_path: Path
|
|
||||||
) -> Path:
|
|
||||||
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
|
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
|
||||||
raise Exception(f"Vae conversion not supported for model type: {config.base}")
|
raise Exception(f"Vae conversion not supported for model type: {config.base}")
|
||||||
else:
|
else:
|
||||||
config_file = 'v1-inference.yaml' if config.base == BaseModelType.StableDiffusion1 else "v2-inference-v.yaml"
|
config_file = (
|
||||||
|
"v1-inference.yaml" if config.base == BaseModelType.StableDiffusion1 else "v2-inference-v.yaml"
|
||||||
|
)
|
||||||
|
|
||||||
if weights_path.suffix == ".safetensors":
|
if weights_path.suffix == ".safetensors":
|
||||||
checkpoint = safetensors.torch.load_file(weights_path, device="cpu")
|
checkpoint = safetensors.torch.load_file(weights_path, device="cpu")
|
||||||
else:
|
else:
|
||||||
checkpoint = torch.load(weights_path, map_location="cpu")
|
checkpoint = torch.load(weights_path, map_location="cpu")
|
||||||
|
|
||||||
dtype = torch_dtype()
|
|
||||||
|
|
||||||
# sometimes weights are hidden under "state_dict", and sometimes not
|
# sometimes weights are hidden under "state_dict", and sometimes not
|
||||||
if "state_dict" in checkpoint:
|
if "state_dict" in checkpoint:
|
||||||
checkpoint = checkpoint["state_dict"]
|
checkpoint = checkpoint["state_dict"]
|
||||||
@ -71,13 +57,11 @@ class VaeDiffusersModel(ModelLoader):
|
|||||||
ckpt_config = OmegaConf.load(self._app_config.legacy_conf_path / config_file)
|
ckpt_config = OmegaConf.load(self._app_config.legacy_conf_path / config_file)
|
||||||
assert isinstance(ckpt_config, DictConfig)
|
assert isinstance(ckpt_config, DictConfig)
|
||||||
|
|
||||||
print(f'DEBUG: CONVERTIGN')
|
|
||||||
vae_model = convert_ldm_vae_to_diffusers(
|
vae_model = convert_ldm_vae_to_diffusers(
|
||||||
checkpoint=checkpoint,
|
checkpoint=checkpoint,
|
||||||
vae_config=ckpt_config,
|
vae_config=ckpt_config,
|
||||||
image_size=512,
|
image_size=512,
|
||||||
)
|
)
|
||||||
vae_model.to(dtype) # set precision appropriately
|
vae_model.to(self._torch_dtype) # set precision appropriately
|
||||||
vae_model.save_pretrained(output_path, safe_serialization=True, torch_dtype=dtype)
|
vae_model.save_pretrained(output_path, safe_serialization=True)
|
||||||
return output_path
|
return output_path
|
||||||
|
|
||||||
|
@ -8,10 +8,11 @@ from typing import Optional, Union
|
|||||||
import torch
|
import torch
|
||||||
from diffusers import DiffusionPipeline
|
from diffusers import DiffusionPipeline
|
||||||
|
|
||||||
|
from invokeai.backend.model_manager.config import AnyModel
|
||||||
from invokeai.backend.model_manager.onnx_runtime import IAIOnnxRuntimeModel
|
from invokeai.backend.model_manager.onnx_runtime import IAIOnnxRuntimeModel
|
||||||
|
|
||||||
|
|
||||||
def calc_model_size_by_data(model: Union[DiffusionPipeline, torch.nn.Module, IAIOnnxRuntimeModel]) -> int:
|
def calc_model_size_by_data(model: AnyModel) -> int:
|
||||||
"""Get size of a model in memory in bytes."""
|
"""Get size of a model in memory in bytes."""
|
||||||
if isinstance(model, DiffusionPipeline):
|
if isinstance(model, DiffusionPipeline):
|
||||||
return _calc_pipeline_by_data(model)
|
return _calc_pipeline_by_data(model)
|
||||||
|
620
invokeai/backend/model_manager/lora.py
Normal file
620
invokeai/backend/model_manager/lora.py
Normal file
@ -0,0 +1,620 @@
|
|||||||
|
# Copyright (c) 2024 The InvokeAI Development team
|
||||||
|
"""LoRA model support."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Optional, Union, List, Tuple
|
||||||
|
from typing_extensions import Self
|
||||||
|
from invokeai.backend.model_manager import BaseModelType
|
||||||
|
|
||||||
|
class LoRALayerBase:
    """State and helpers common to every LoRA-style layer.

    Attributes assigned by ``__init__``:
        alpha: Python scalar taken from ``values["alpha"]``, or ``None`` when that
            entry is absent.
        bias: Sparse COO tensor assembled from the ``bias_indices`` /
            ``bias_values`` / ``bias_size`` entries, or ``None`` unless all three
            are present.
        rank: ``None`` here; concrete layer classes overwrite it.
        layer_key: The key this layer was created under.
    """

    # @property
    # def scale(self):
    #     return self.alpha / self.rank if (self.alpha and self.rank) else 1.0

    def __init__(
        self,
        layer_key: str,
        values: Dict[str, torch.Tensor],
    ):
        """Read the optional ``alpha`` and sparse-bias entries out of ``values``."""
        bias_parts = ("bias_indices", "bias_values", "bias_size")

        self.layer_key = layer_key
        self.rank = None  # filled in by the concrete layer class
        self.alpha = values["alpha"].item() if "alpha" in values else None

        self.bias: Optional[torch.Tensor] = None
        if all(part in values for part in bias_parts):
            self.bias = torch.sparse_coo_tensor(
                values["bias_indices"],
                values["bias_values"],
                tuple(values["bias_size"]),
            )

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        """Return this layer's weight tensor; must be provided by subclasses."""
        raise NotImplementedError()

    def calc_size(self) -> int:
        """Return the number of bytes held by the bias tensor (0 when there is none)."""
        bias = self.bias
        if bias is None:
            return 0
        return bias.nelement() * bias.element_size()

    def to(
        self,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        """Move/cast the bias tensor, if any, replacing the stored reference."""
        if self.bias is None:
            return
        self.bias = self.bias.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: find and debug lora/locon with bias
|
||||||
|
class LoRALayer(LoRALayerBase):
    """LoRA / LoCon layer built from ``lora_up``, ``lora_down`` and optional ``lora_mid`` weights.

    Attributes:
        up: ``values["lora_up.weight"]``.
        down: ``values["lora_down.weight"]``.
        mid: ``values["lora_mid.weight"]`` when present, otherwise ``None``.
        rank: First dimension of ``down``.
    """

    def __init__(
        self,
        layer_key: str,
        values: Dict[str, torch.Tensor],
    ):
        """Pull the up/down/mid tensors out of ``values``."""
        super().__init__(layer_key, values)

        has_mid = "lora_mid.weight" in values
        self.up = values["lora_up.weight"]
        self.down = values["lora_down.weight"]
        self.mid: Optional[torch.Tensor] = values["lora_mid.weight"] if has_mid else None

        self.rank = self.down.shape[0]

    def get_weight(self, orig_weight: torch.Tensor):
        """Return the weight computed from up/down (and mid, if present).

        ``orig_weight`` is accepted for interface compatibility and is not read.
        """
        if self.mid is None:
            # Flatten both factors to 2-D and multiply.
            up_2d = self.up.reshape(self.up.shape[0], -1)
            down_2d = self.down.reshape(self.down.shape[0], -1)
            return up_2d @ down_2d

        up_2d = self.up.reshape(self.up.shape[0], self.up.shape[1])
        down_2d = self.down.reshape(self.down.shape[0], self.down.shape[1])
        return torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up_2d, down_2d)

    def calc_size(self) -> int:
        """Return the byte size of the base-class tensors plus up/mid/down."""
        own_bytes = sum(
            tensor.nelement() * tensor.element_size()
            for tensor in (self.up, self.mid, self.down)
            if tensor is not None
        )
        return super().calc_size() + own_bytes

    def to(
        self,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        """Move/cast every tensor held by this layer, replacing the stored references."""
        super().to(device=device, dtype=dtype)

        self.up = self.up.to(device=device, dtype=dtype)
        self.down = self.down.to(device=device, dtype=dtype)

        mid = self.mid
        if mid is not None:
            self.mid = mid.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class LoHALayer(LoRALayerBase):
    """LoHA layer: the weight is the element-wise product of two factorised terms.

    Attributes:
        w1_a, w1_b, w2_a, w2_b: The ``hada_w1_a`` / ``hada_w1_b`` / ``hada_w2_a`` /
            ``hada_w2_b`` entries of ``values``.
        t1, t2: The optional ``hada_t1`` / ``hada_t2`` entries, ``None`` when absent.
        rank: First dimension of ``w1_b``.
    """

    def __init__(
        self,
        layer_key: str,
        values: Dict[str, torch.Tensor]
    ):
        """Pull the ``hada_*`` tensors out of ``values``."""
        super().__init__(layer_key, values)

        self.w1_a = values["hada_w1_a"]
        self.w1_b = values["hada_w1_b"]
        self.w2_a = values["hada_w2_a"]
        self.w2_b = values["hada_w2_b"]

        self.t1: Optional[torch.Tensor] = values["hada_t1"] if "hada_t1" in values else None
        self.t2: Optional[torch.Tensor] = values["hada_t2"] if "hada_t2" in values else None

        self.rank = self.w1_b.shape[0]

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        """Return the product of the two rebuilt terms.

        ``orig_weight`` is accepted for interface compatibility and is not read.
        The presence of ``t1`` alone selects the einsum path.
        """
        if self.t1 is not None:
            pattern = "i j k l, j r, i p -> p r k l"
            first = torch.einsum(pattern, self.t1, self.w1_b, self.w1_a)
            second = torch.einsum(pattern, self.t2, self.w2_b, self.w2_a)
            return first * second

        product: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
        return product

    def calc_size(self) -> int:
        """Return the byte size of the base-class tensors plus all ``hada_*`` tensors."""
        held = (self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2)
        own_bytes = sum(t.nelement() * t.element_size() for t in held if t is not None)
        return super().calc_size() + own_bytes

    def to(
        self,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        """Move/cast every tensor held by this layer, replacing the stored references."""
        super().to(device=device, dtype=dtype)

        # First term.
        self.w1_a = self.w1_a.to(device=device, dtype=dtype)
        self.w1_b = self.w1_b.to(device=device, dtype=dtype)
        t1 = self.t1
        if t1 is not None:
            self.t1 = t1.to(device=device, dtype=dtype)

        # Second term.
        self.w2_a = self.w2_a.to(device=device, dtype=dtype)
        self.w2_b = self.w2_b.to(device=device, dtype=dtype)
        t2 = self.t2
        if t2 is not None:
            self.t2 = t2.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class LoKRLayer(LoRALayerBase):
    """LoKR layer: the weight is ``torch.kron(w1, w2)``.

    Each of ``w1`` / ``w2`` is either stored whole (``lokr_w1`` / ``lokr_w2``) or as
    an ``_a`` / ``_b`` pair; whichever form is not used is left as ``None``.
    ``t2`` is the optional ``lokr_t2`` entry. ``rank`` is the first dimension of
    ``lokr_w1_b``, else of ``lokr_w2_b``, else ``None``.
    """

    def __init__(
        self,
        layer_key: str,
        values: Dict[str, torch.Tensor],
    ):
        """Pull the ``lokr_*`` tensors out of ``values``."""
        super().__init__(layer_key, values)

        whole_w1 = "lokr_w1" in values
        self.w1: Optional[torch.Tensor] = values["lokr_w1"] if whole_w1 else None
        self.w1_a = None if whole_w1 else values["lokr_w1_a"]
        self.w1_b = None if whole_w1 else values["lokr_w1_b"]

        whole_w2 = "lokr_w2" in values
        self.w2: Optional[torch.Tensor] = values["lokr_w2"] if whole_w2 else None
        self.w2_a = None if whole_w2 else values["lokr_w2_a"]
        self.w2_b = None if whole_w2 else values["lokr_w2_b"]

        self.t2: Optional[torch.Tensor] = values["lokr_t2"] if "lokr_t2" in values else None

        # Prefer w1_b for the rank, fall back to w2_b, otherwise leave unscaled.
        self.rank = None
        for rank_source in ("lokr_w1_b", "lokr_w2_b"):
            if rank_source in values:
                self.rank = values[rank_source].shape[0]
                break

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        """Return the Kronecker product of the (possibly rebuilt) ``w1`` and ``w2``.

        ``orig_weight`` is accepted for interface compatibility and is not read.
        """
        left: Optional[torch.Tensor] = self.w1
        if left is None:
            assert self.w1_a is not None
            assert self.w1_b is not None
            left = self.w1_a @ self.w1_b

        right = self.w2
        if right is None:
            if self.t2 is not None:
                right = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b)
            else:
                assert self.w2_a is not None
                assert self.w2_b is not None
                right = self.w2_a @ self.w2_b

        # A 4-D right factor needs two trailing singleton axes on the left factor.
        if len(right.shape) == 4:
            left = left.unsqueeze(2).unsqueeze(2)
        right = right.contiguous()
        assert left is not None
        assert right is not None
        return torch.kron(left, right)

    def calc_size(self) -> int:
        """Return the byte size of the base-class tensors plus all ``lokr_*`` tensors."""
        held = (self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2)
        own_bytes = sum(t.nelement() * t.element_size() for t in held if t is not None)
        return super().calc_size() + own_bytes

    def to(
        self,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        """Move/cast every tensor held by this layer, replacing the stored references."""
        super().to(device=device, dtype=dtype)

        if self.w1 is None:
            assert self.w1_a is not None
            assert self.w1_b is not None
            self.w1_a = self.w1_a.to(device=device, dtype=dtype)
            self.w1_b = self.w1_b.to(device=device, dtype=dtype)
        else:
            self.w1 = self.w1.to(device=device, dtype=dtype)

        if self.w2 is None:
            assert self.w2_a is not None
            assert self.w2_b is not None
            self.w2_a = self.w2_a.to(device=device, dtype=dtype)
            self.w2_b = self.w2_b.to(device=device, dtype=dtype)
        else:
            self.w2 = self.w2.to(device=device, dtype=dtype)

        t2 = self.t2
        if t2 is not None:
            self.t2 = t2.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class FullLayer(LoRALayerBase):
    """Layer that stores a complete weight tensor under the ``diff`` key.

    Attributes:
        weight: ``values["diff"]``.
        rank: Always ``None`` for this layer type.
    """

    def __init__(
        self,
        layer_key: str,
        values: Dict[str, torch.Tensor],
    ):
        """Store the ``diff`` tensor and reject any other entry in ``values``.

        Raises:
            NotImplementedError: If ``values`` holds keys other than ``diff``.
        """
        super().__init__(layer_key, values)

        self.weight = values["diff"]

        unexpected = [key for key in values if key != "diff"]
        if unexpected:
            raise NotImplementedError(f"Unexpected keys in lora diff layer: {unexpected}")

        self.rank = None  # unscaled

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        """Return the stored tensor; ``orig_weight`` is not read."""
        return self.weight

    def calc_size(self) -> int:
        """Return the byte size of the base-class tensors plus the stored weight."""
        weight_bytes = self.weight.nelement() * self.weight.element_size()
        return super().calc_size() + weight_bytes

    def to(
        self,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Move/cast the stored tensors, replacing the stored references."""
        super().to(device=device, dtype=dtype)

        moved = self.weight.to(device=device, dtype=dtype)
        self.weight = moved
|
||||||
|
|
||||||
|
|
||||||
|
class IA3Layer(LoRALayerBase):
    """IA3 layer: multiplies the original weight by a stored tensor.

    Attributes:
        weight: ``values["weight"]``.
        on_input: ``values["on_input"]``; its truth value decides whether ``weight``
            is reshaped to a column before multiplying.
        rank: Always ``None`` for this layer type.
    """

    def __init__(
        self,
        layer_key: str,
        values: Dict[str, torch.Tensor],
    ):
        """Pull the ``weight`` and ``on_input`` tensors out of ``values``."""
        super().__init__(layer_key, values)

        self.weight = values["weight"]
        self.on_input = values["on_input"]

        self.rank = None  # unscaled

    def get_weight(self, orig_weight: torch.Tensor):
        """Return ``orig_weight`` multiplied by the stored weight.

        When ``on_input`` is falsy the stored weight is first reshaped to ``(-1, 1)``.
        """
        factor = self.weight if self.on_input else self.weight.reshape(-1, 1)
        return orig_weight * factor

    def calc_size(self) -> int:
        """Return the byte size of the base-class tensors plus ``weight`` and ``on_input``."""
        total = super().calc_size()
        for tensor in (self.weight, self.on_input):
            total += tensor.nelement() * tensor.element_size()
        return total

    def to(
        self,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Move/cast both stored tensors, replacing the stored references."""
        super().to(device=device, dtype=dtype)

        moved_weight = self.weight.to(device=device, dtype=dtype)
        moved_flag = self.on_input.to(device=device, dtype=dtype)
        self.weight = moved_weight
        self.on_input = moved_flag
|
||||||
|
|
||||||
|
# Type alias covering every concrete layer class that a LoRA model's `layers` dict may hold.
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
|
||||||
|
|
||||||
|
# TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix
|
||||||
|
class LoRAModelRaw:  # (torch.nn.Module):
    """A named collection of parsed LoRA layers, keyed by layer key."""

    _name: str
    layers: Dict[str, AnyLoRALayer]

    def __init__(
        self,
        name: str,
        layers: Dict[str, AnyLoRALayer],
    ):
        """Store the model name and the (possibly empty) layer mapping."""
        self._name = name
        self.layers = layers

    @property
    def name(self) -> str:
        """Return the name given at construction time."""
        return self._name

    def to(
        self,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        """Move/cast every layer in place by delegating to each layer's ``to``."""
        # TODO: try revert if exception?
        for layer in self.layers.values():
            layer.to(device=device, dtype=dtype)

    def calc_size(self) -> int:
        """Return the summed ``calc_size()`` of all layers, in bytes."""
        return sum(layer.calc_size() for layer in self.layers.values())

    @classmethod
    def _convert_sdxl_keys_to_diffusers_format(cls, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Convert the keys of an SDXL LoRA state_dict to diffusers format.

        The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
        diffusers format, then this function will have no effect.

        Only the keys are inspected; values are carried over untouched. ``from_checkpoint`` calls this with the
        output of ``_group_state``, so in that call the values are per-layer dicts rather than tensors.

        This function is adapted from:
        https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409

        Args:
            state_dict (Dict[str, Tensor]): The SDXL LoRA state_dict.

        Raises:
            ValueError: If state_dict contains an unrecognized key, or not all keys could be converted.

        Returns:
            Dict[str, Tensor]: The diffusers-format state_dict.
        """
        # Imported locally because the module-level imports do not provide `bisect`;
        # without this the first SDXL conversion fails with NameError.
        import bisect

        converted_count = 0  # The number of Stability AI keys converted to diffusers format.
        not_converted_count = 0  # The number of keys that were not converted.

        # Get a sorted list of Stability AI UNet keys so that we can efficiently search for keys with matching prefixes.
        # For example, we want to efficiently find `input_blocks_4_1` in the list when searching for
        # `input_blocks_4_1_proj_in`.
        stability_unet_keys = sorted(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)

        new_state_dict = {}
        for full_key, value in state_dict.items():
            if full_key.startswith("lora_unet_"):
                search_key = full_key.replace("lora_unet_", "")
                # Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix.
                # If position is 0 the index wraps to the last (largest) key, which cannot be a prefix of a
                # search_key that sorts before every key, so the startswith check below rejects it.
                position = bisect.bisect_right(stability_unet_keys, search_key)
                map_key = stability_unet_keys[position - 1]
                # Now, check if the map_key *actually* matches the search_key.
                if search_key.startswith(map_key):
                    new_key = full_key.replace(map_key, SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP[map_key])
                    new_state_dict[new_key] = value
                    converted_count += 1
                else:
                    new_state_dict[full_key] = value
                    not_converted_count += 1
            elif full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
                # The CLIP text encoders have the same keys in both Stability AI and diffusers formats.
                new_state_dict[full_key] = value
                continue
            else:
                raise ValueError(f"Unrecognized SDXL LoRA key prefix: '{full_key}'.")

        if converted_count > 0 and not_converted_count > 0:
            raise ValueError(
                f"The SDXL LoRA could only be partially converted to diffusers format. converted={converted_count},"
                f" not_converted={not_converted_count}"
            )

        return new_state_dict

    @classmethod
    def from_checkpoint(
        cls,
        file_path: Union[str, Path],
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        base_model: Optional[BaseModelType] = None,
    ) -> Self:
        """Load a LoRA checkpoint file and build a model from it.

        Args:
            file_path: Path to a ``.safetensors`` file or any other file readable by ``torch.load``.
            device: Target device for the parsed layers; defaults to CPU.
            dtype: Target dtype for the parsed layers; defaults to ``torch.float32``.
            base_model: When equal to ``BaseModelType.StableDiffusionXL``, layer keys are first
                converted to diffusers format.

        Returns:
            A new instance named after the file stem, with one layer per grouped key.

        Raises:
            Exception: If a layer's entries match none of the known layer formats.
            ValueError: Propagated from the SDXL key conversion.
        """
        device = device or torch.device("cpu")
        dtype = dtype or torch.float32

        if isinstance(file_path, str):
            file_path = Path(file_path)

        model = cls(
            name=file_path.stem,  # TODO:
            layers={},
        )

        if file_path.suffix == ".safetensors":
            state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
        else:
            # NOTE(review): torch.load unpickles the file; only use on checkpoints from trusted sources.
            state_dict = torch.load(file_path, map_location="cpu")

        state_dict = cls._group_state(state_dict)

        if base_model == BaseModelType.StableDiffusionXL:
            state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)

        for layer_key, values in state_dict.items():
            # lora and locon
            if "lora_down.weight" in values:
                layer: AnyLoRALayer = LoRALayer(layer_key, values)

            # loha
            elif "hada_w1_b" in values:
                layer = LoHALayer(layer_key, values)

            # lokr
            elif "lokr_w1_b" in values or "lokr_w1" in values:
                layer = LoKRLayer(layer_key, values)

            # diff
            elif "diff" in values:
                layer = FullLayer(layer_key, values)

            # ia3
            elif "weight" in values and "on_input" in values:
                layer = IA3Layer(layer_key, values)

            else:
                print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
                raise Exception("Unknown lora format!")

            # lower memory consumption by removing already parsed layer values
            state_dict[layer_key].clear()

            layer.to(device=device, dtype=dtype)
            model.layers[layer_key] = layer

        return model

    @staticmethod
    def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
        """Group a flat state dict by the part of each key before its first ``.``.

        For example ``{"a.x": t1, "a.y.z": t2}`` becomes ``{"a": {"x": t1, "y.z": t2}}``.
        A key that contains no ``.`` makes the unpacking below raise ``ValueError``.
        """
        state_dict_groupped: Dict[str, Dict[str, torch.Tensor]] = {}

        for key, value in state_dict.items():
            stem, leaf = key.split(".", 1)
            state_dict_groupped.setdefault(stem, {})[leaf] = value

        return state_dict_groupped
|
||||||
|
|
||||||
|
|
||||||
|
# code from
|
||||||
|
# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
|
||||||
|
def make_sdxl_unet_conversion_map() -> List[Tuple[str, str]]:
    """Build the ordered list of SDXL UNet key-prefix pairs.

    Returns:
        A list of ``(stability_ai_prefix, diffusers_prefix)`` tuples. Block-level
        prefixes come first (resnet blocks expanded into their six sub-layers),
        followed by the time-embedding, label-embedding and in/out convolution
        entries.
    """
    layer_pairs: List[Tuple[str, str]] = []

    for block in range(3):  # num_blocks is 3 in sdxl
        # Down blocks: two resnet/attention pairs each.
        for idx in range(2):
            sd_index = 3 * block + idx + 1
            layer_pairs.append((f"input_blocks.{sd_index}.0.", f"down_blocks.{block}.resnets.{idx}."))
            layer_pairs.append((f"input_blocks.{sd_index}.1.", f"down_blocks.{block}.attentions.{idx}."))

        # Up blocks: three resnet/attention pairs each.
        for idx in range(3):
            sd_index = 3 * block + idx
            layer_pairs.append((f"output_blocks.{sd_index}.0.", f"up_blocks.{block}.resnets.{idx}."))
            layer_pairs.append((f"output_blocks.{sd_index}.1.", f"up_blocks.{block}.attentions.{idx}."))

        # Down/up samplers for this block.
        layer_pairs.append((f"input_blocks.{3 * (block + 1)}.0.op.", f"down_blocks.{block}.downsamplers.0.conv."))
        layer_pairs.append((f"output_blocks.{3 * block + 2}.2.", f"up_blocks.{block}.upsamplers.0."))  # change for sdxl

    # Middle block: one attention followed by two resnets.
    layer_pairs.append(("middle_block.1.", "mid_block.attentions.0."))
    for idx in range(2):
        layer_pairs.append((f"middle_block.{2 * idx}.", f"mid_block.resnets.{idx}."))

    # (stable-diffusion, HF Diffusers) names of the sub-layers inside a resnet block.
    resnet_pairs = (
        ("in_layers.0.", "norm1."),
        ("in_layers.2.", "conv1."),
        ("out_layers.0.", "norm2."),
        ("out_layers.3.", "conv2."),
        ("emb_layers.1.", "time_emb_proj."),
        ("skip_connection.", "conv_shortcut."),
    )

    conversion_map: List[Tuple[str, str]] = []
    for sd_prefix, hf_prefix in layer_pairs:
        if "resnets" not in hf_prefix:
            conversion_map.append((sd_prefix, hf_prefix))
            continue
        conversion_map.extend((sd_prefix + sd_sub, hf_prefix + hf_sub) for sd_sub, hf_sub in resnet_pairs)

    conversion_map.extend((f"time_embed.{idx * 2}.", f"time_embedding.linear_{idx + 1}.") for idx in range(2))
    conversion_map.extend((f"label_emb.0.{idx * 2}.", f"add_embedding.linear_{idx + 1}.") for idx in range(2))

    conversion_map.extend(
        [
            ("input_blocks.0.0.", "conv_in."),
            ("out.0.", "conv_norm_out."),
            ("out.2.", "conv_out."),
        ]
    )

    return conversion_map
|
||||||
|
|
||||||
|
|
||||||
|
# Lookup table from Stability AI SDXL UNet key fragments to diffusers key fragments.
# Every prefix from make_sdxl_unet_conversion_map() loses its trailing "." and has its
# remaining dots replaced by underscores, e.g. "input_blocks.0.0." -> "input_blocks_0_0".
SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = dict(
    (stability_prefix.rstrip(".").replace(".", "_"), diffusers_prefix.rstrip(".").replace(".", "_"))
    for stability_prefix, diffusers_prefix in make_sdxl_unet_conversion_map()
)
|
@ -29,8 +29,12 @@ CkptType = Dict[str, Any]
|
|||||||
|
|
||||||
LEGACY_CONFIGS: Dict[BaseModelType, Dict[ModelVariantType, Union[str, Dict[SchedulerPredictionType, str]]]] = {
|
LEGACY_CONFIGS: Dict[BaseModelType, Dict[ModelVariantType, Union[str, Dict[SchedulerPredictionType, str]]]] = {
|
||||||
BaseModelType.StableDiffusion1: {
|
BaseModelType.StableDiffusion1: {
|
||||||
ModelVariantType.Normal: "v1-inference.yaml",
|
ModelVariantType.Normal: {
|
||||||
|
SchedulerPredictionType.Epsilon: "v1-inference.yaml",
|
||||||
|
SchedulerPredictionType.VPrediction: "v1-inference-v.yaml",
|
||||||
|
},
|
||||||
ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
|
ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
|
||||||
|
ModelVariantType.Depth: "v2-midas-inference.yaml",
|
||||||
},
|
},
|
||||||
BaseModelType.StableDiffusion2: {
|
BaseModelType.StableDiffusion2: {
|
||||||
ModelVariantType.Normal: {
|
ModelVariantType.Normal: {
|
||||||
|
@ -20,6 +20,14 @@ from .util import ( # TO DO: Clean this up; remove the unused symbols
|
|||||||
download_with_resume,
|
download_with_resume,
|
||||||
instantiate_from_config, # noqa
|
instantiate_from_config, # noqa
|
||||||
url_attachment_name, # noqa
|
url_attachment_name, # noqa
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = ["GIG", "directory_size","Chdir", "download_with_resume", "InvokeAILogger", "choose_precision", "choose_torch_device"]
|
__all__ = [
|
||||||
|
"GIG",
|
||||||
|
"directory_size",
|
||||||
|
"Chdir",
|
||||||
|
"download_with_resume",
|
||||||
|
"InvokeAILogger",
|
||||||
|
"choose_precision",
|
||||||
|
"choose_torch_device",
|
||||||
|
]
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from typing import Union, Optional
|
from typing import Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import autocast
|
from torch import autocast
|
||||||
|
@ -27,6 +27,7 @@ from .devices import torch_dtype
|
|||||||
# actual size of a gig
|
# actual size of a gig
|
||||||
GIG = 1073741824
|
GIG = 1073741824
|
||||||
|
|
||||||
|
|
||||||
def directory_size(directory: Path) -> int:
|
def directory_size(directory: Path) -> int:
|
||||||
"""
|
"""
|
||||||
Return the aggregate size of all files in a directory (bytes).
|
Return the aggregate size of all files in a directory (bytes).
|
||||||
@ -39,6 +40,7 @@ def directory_size(directory: Path) -> int:
|
|||||||
sum += Path(root, d).stat().st_size
|
sum += Path(root, d).stat().st_size
|
||||||
return sum
|
return sum
|
||||||
|
|
||||||
|
|
||||||
def log_txt_as_img(wh, xc, size=10):
|
def log_txt_as_img(wh, xc, size=10):
|
||||||
# wh a tuple of (width, height)
|
# wh a tuple of (width, height)
|
||||||
# xc a list of captions to plot
|
# xc a list of captions to plot
|
||||||
|
Loading…
Reference in New Issue
Block a user