diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py
index 0f2a92b5c8..dcb8d21997 100644
--- a/invokeai/app/api/dependencies.py
+++ b/invokeai/app/api/dependencies.py
@@ -8,6 +8,8 @@ from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMe
 from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk
 from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache
 from invokeai.app.services.shared.sqlite.sqlite_util import init_db
+from invokeai.backend.model_manager.load import AnyModelLoader, ModelConvertCache
+from invokeai.backend.model_manager.load.model_cache import ModelCache
 from invokeai.backend.model_manager.metadata import ModelMetadataStore
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
 from invokeai.backend.util.logging import InvokeAILogger
@@ -98,15 +100,26 @@ class ApiDependencies:
         )
         model_manager = ModelManagerService(config, logger)
         model_record_service = ModelRecordServiceSQL(db=db)
+        model_loader = AnyModelLoader(
+            app_config=config,
+            logger=logger,
+            ram_cache=ModelCache(
+                max_cache_size=config.ram_cache_size, max_vram_cache_size=config.vram_cache_size, logger=logger
+            ),
+            convert_cache=ModelConvertCache(
+                cache_path=config.models_convert_cache_path, max_size=config.convert_cache_size
+            ),
+        )
+        model_record_service = ModelRecordServiceSQL(db=db, loader=model_loader)
         download_queue_service = DownloadQueueService(event_bus=events)
-        metadata_store = ModelMetadataStore(db=db)
         model_install_service = ModelInstallService(
             app_config=config,
             record_store=model_record_service,
             download_queue=download_queue_service,
-            metadata_store=metadata_store,
+            metadata_store=ModelMetadataStore(db=db),
             event_bus=events,
         )
+        model_manager = ModelManagerService(config, logger)  # TO DO: legacy model manager v1. Remove
         names = SimpleNameService()
         performance_statistics = InvocationStatsService()
         processor = DefaultInvocationProcessor()
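The hunk above threads a new AnyModelLoader (a RAM/VRAM model cache plus an on-disk cache for converted checkpoints) into ModelRecordServiceSQL, so a caller can go from a model key in the database straight to a loaded model. Below is a minimal sketch of how that path might be exercised, assuming an already-initialized SqliteDatabase and an existing model key; the key argument and the SubModelType.UNet submodel are placeholders, not values taken from this diff:

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records.model_records_sql import ModelRecordServiceSQL
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.backend.model_manager.config import SubModelType
from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel, ModelConvertCache
from invokeai.backend.model_manager.load.model_cache import ModelCache
from invokeai.backend.util.logging import InvokeAILogger


def load_main_unet(db: SqliteDatabase, key: str) -> LoadedModel:
    """Sketch: fetch a model config by key and load one submodel through the new loader stack."""
    config = InvokeAIAppConfig.get_config()
    logger = InvokeAILogger.get_logger(__name__)
    # Mirror the wiring in ApiDependencies.initialize() above: RAM/VRAM caches sized
    # from the app config, plus the new on-disk cache for converted checkpoints.
    loader = AnyModelLoader(
        app_config=config,
        logger=logger,
        ram_cache=ModelCache(
            max_cache_size=config.ram_cache_size, max_vram_cache_size=config.vram_cache_size, logger=logger
        ),
        convert_cache=ModelConvertCache(
            cache_path=config.models_convert_cache_path, max_size=config.convert_cache_size
        ),
    )
    store = ModelRecordServiceSQL(db=db, loader=loader)
    # SubModelType.UNet is an assumed enum member; load_model() raises NotImplementedError
    # if the store was constructed without a loader (see model_records_sql.py later in this diff).
    return store.load_model(key, SubModelType.UNet)

Because the loader argument is optional on ModelRecordServiceSQL, existing call sites that only read or write records keep working unchanged; only load_model() requires it.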
diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py
index 132afc2272..b161ea18d6 100644
--- a/invokeai/app/services/config/config_default.py
+++ b/invokeai/app/services/config/config_default.py
@@ -237,6 +237,7 @@ class InvokeAIAppConfig(InvokeAISettings):
     autoimport_dir : Path = Field(default=Path('autoimport'), description='Path to a directory of models files to be imported on startup.', json_schema_extra=Categories.Paths)
     conf_path : Path = Field(default=Path('configs/models.yaml'), description='Path to models definition file', json_schema_extra=Categories.Paths)
     models_dir : Path = Field(default=Path('models'), description='Path to the models directory', json_schema_extra=Categories.Paths)
+    convert_cache_dir : Path = Field(default=Path('models/.cache'), description='Path to the converted models cache directory', json_schema_extra=Categories.Paths)
     legacy_conf_dir : Path = Field(default=Path('configs/stable-diffusion'), description='Path to directory of legacy checkpoint config files', json_schema_extra=Categories.Paths)
     db_dir : Path = Field(default=Path('databases'), description='Path to InvokeAI databases directory', json_schema_extra=Categories.Paths)
     outdir : Path = Field(default=Path('outputs'), description='Default folder for output images', json_schema_extra=Categories.Paths)
@@ -262,6 +263,8 @@ class InvokeAIAppConfig(InvokeAISettings):
     # CACHE
     ram : float = Field(default=7.5, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number, GB)", json_schema_extra=Categories.ModelCache, )
     vram : float = Field(default=0.25, ge=0, description="Amount of VRAM reserved for model storage (floating point number, GB)", json_schema_extra=Categories.ModelCache, )
+    convert_cache : float = Field(default=10.0, ge=0, description="Maximum size of on-disk converted models cache (GB)", json_schema_extra=Categories.ModelCache)
+    lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", json_schema_extra=Categories.ModelCache, )
     log_memory_usage : bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.", json_schema_extra=Categories.ModelCache)
@@ -404,6 +407,11 @@ class InvokeAIAppConfig(InvokeAISettings):
         """Path to the models directory."""
         return self._resolve(self.models_dir)
 
+    @property
+    def models_convert_cache_path(self) -> Path:
+        """Path to the converted cache models directory."""
+        return self._resolve(self.convert_cache_dir)
+
     @property
     def custom_nodes_path(self) -> Path:
         """Path to the custom nodes directory."""
@@ -433,15 +441,20 @@ class InvokeAIAppConfig(InvokeAISettings):
         return True
 
     @property
-    def ram_cache_size(self) -> Union[Literal["auto"], float]:
-        """Return the ram cache size using the legacy or modern setting."""
+    def ram_cache_size(self) -> float:
+        """Return the ram cache size using the legacy or modern setting (GB)."""
         return self.max_cache_size or self.ram
 
     @property
-    def vram_cache_size(self) -> Union[Literal["auto"], float]:
-        """Return the vram cache size using the legacy or modern setting."""
+    def vram_cache_size(self) -> float:
+        """Return the vram cache size using the legacy or modern setting (GB)."""
         return self.max_vram_cache_size or self.vram
 
+    @property
+    def convert_cache_size(self) -> float:
+        """Return the convert cache size on disk (GB)."""
+        return self.convert_cache
+
     @property
     def use_cpu(self) -> bool:
         """Return true if the device is set to CPU or the always_use_cpu flag is set."""
diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py
index 82c667f584..2b2294bfce 100644
--- a/invokeai/app/services/model_install/model_install_default.py
+++ b/invokeai/app/services/model_install/model_install_default.py
@@ -145,7 +145,7 @@ class ModelInstallService(ModelInstallServiceBase):
     ) -> str:  # noqa D102
         model_path = Path(model_path)
         config = config or {}
-        if config.get("source") is None:
+        if not config.get("source"):
             config["source"] = model_path.resolve().as_posix()
         return self._register(model_path, config)
 
@@ -156,7 +156,7 @@ class ModelInstallService(ModelInstallServiceBase):
     ) -> str:  # noqa D102
         model_path = Path(model_path)
         config = config or {}
-        if config.get("source") is None:
+        if not config.get("source"):
             config["source"] = model_path.resolve().as_posix()
         info: AnyModelConfig = self._probe_model(Path(model_path), config)
 
@@ -300,6 +300,7 @@ class ModelInstallService(ModelInstallServiceBase):
             job.total_bytes = self._stat_size(job.local_path)
             job.bytes = job.total_bytes
             self._signal_job_running(job)
+            job.config_in["source"] = str(job.source)
             if job.inplace:
                 key = self.register_path(job.local_path, job.config_in)
             else:
diff --git a/invokeai/app/services/model_records/model_records_base.py b/invokeai/app/services/model_records/model_records_base.py
index 57597570cd..31cfecb4ec 100644
--- a/invokeai/app/services/model_records/model_records_base.py
+++ b/invokeai/app/services/model_records/model_records_base.py
@@ -11,7 +11,7 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union
 from pydantic import BaseModel, Field
 
 from invokeai.app.services.shared.pagination import PaginatedResults
-from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType
+from invokeai.backend.model_manager import LoadedModel, AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
 from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata,
ModelMetadataStore @@ -102,6 +102,19 @@ class ModelRecordServiceBase(ABC): """ pass + @abstractmethod + def load_model(self, key: str, submodel_type: Optional[SubModelType]) -> LoadedModel: + """ + Load the indicated model into memory and return a LoadedModel object. + + :param key: Key of model config to be fetched. + :param submodel_type: For main (pipeline models), the submodel to fetch + + Exceptions: UnknownModelException -- model with this key not known + NotImplementedException -- a model loader was not provided at initialization time + """ + pass + @property @abstractmethod def metadata_store(self) -> ModelMetadataStore: diff --git a/invokeai/app/services/model_records/model_records_sql.py b/invokeai/app/services/model_records/model_records_sql.py index 4512da5d41..eee867ccb4 100644 --- a/invokeai/app/services/model_records/model_records_sql.py +++ b/invokeai/app/services/model_records/model_records_sql.py @@ -42,6 +42,7 @@ Typical usage: import json import sqlite3 +import time from math import ceil from pathlib import Path from typing import Any, Dict, List, Optional, Set, Tuple, Union @@ -53,8 +54,10 @@ from invokeai.backend.model_manager.config import ( ModelConfigFactory, ModelFormat, ModelType, + SubModelType, ) from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore, UnknownMetadataException +from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel from ..shared.sqlite.sqlite_database import SqliteDatabase from .model_records_base import ( @@ -69,16 +72,17 @@ from .model_records_base import ( class ModelRecordServiceSQL(ModelRecordServiceBase): """Implementation of the ModelConfigStore ABC using a SQL database.""" - def __init__(self, db: SqliteDatabase): + def __init__(self, db: SqliteDatabase, loader: Optional[AnyModelLoader]=None): """ Initialize a new object from preexisting sqlite3 connection and threading lock objects. - :param conn: sqlite3 connection object - :param lock: threading Lock object + :param db: Sqlite connection object + :param loader: Initialized model loader object (optional) """ super().__init__() self._db = db - self._cursor = self._db.conn.cursor() + self._cursor = db.conn.cursor() + self._loader = loader @property def db(self) -> SqliteDatabase: @@ -199,7 +203,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): with self._db.lock: self._cursor.execute( """--sql - SELECT config FROM model_config + SELECT config, strftime('%s',updated_at) FROM model_config WHERE id=?; """, (key,), @@ -207,9 +211,24 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): rows = self._cursor.fetchone() if not rows: raise UnknownModelException("model not found") - model = ModelConfigFactory.make_config(json.loads(rows[0])) + model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1]) return model + def load_model(self, key: str, submodel_type: Optional[SubModelType]) -> LoadedModel: + """ + Load the indicated model into memory and return a LoadedModel object. + + :param key: Key of model config to be fetched. + :param submodel_type: For main (pipeline models), the submodel to fetch. 
+ + Exceptions: UnknownModelException -- model with this key not known + NotImplementedException -- a model loader was not provided at initialization time + """ + if not self._loader: + raise NotImplementedError(f"Class {self.__class__} was not initialized with a model loader") + model_config = self.get_model(key) + return self._loader.load_model(model_config, submodel_type) + def exists(self, key: str) -> bool: """ Return True if a model with the indicated key exists in the databse. @@ -265,12 +284,12 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): with self._db.lock: self._cursor.execute( f"""--sql - select config FROM model_config + select config, strftime('%s',updated_at) FROM model_config {where}; """, tuple(bindings), ) - results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()] + results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()] return results def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]: @@ -279,12 +298,12 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): with self._db.lock: self._cursor.execute( """--sql - SELECT config FROM model_config + SELECT config, strftime('%s',updated_at) FROM model_config WHERE path=?; """, (str(path),), ) - results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()] + results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()] return results def search_by_hash(self, hash: str) -> List[AnyModelConfig]: @@ -293,12 +312,12 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): with self._db.lock: self._cursor.execute( """--sql - SELECT config FROM model_config + SELECT config, strftime('%s',updated_at) FROM model_config WHERE original_hash=?; """, (hash,), ) - results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()] + results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()] return results @property diff --git a/invokeai/app/services/shared/sqlite/sqlite_util.py b/invokeai/app/services/shared/sqlite/sqlite_util.py index 6079b3f08d..681886eacd 100644 --- a/invokeai/app/services/shared/sqlite/sqlite_util.py +++ b/invokeai/app/services/shared/sqlite/sqlite_util.py @@ -8,6 +8,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import from invokeai.app.services.shared.sqlite_migrator.migrations.migration_3 import build_migration_3 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_4 import build_migration_4 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_5 import build_migration_5 +from invokeai.app.services.shared.sqlite_migrator.migrations.migration_6 import build_migration_6 from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator @@ -33,6 +34,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto migrator.register_migration(build_migration_3(app_config=config, logger=logger)) migrator.register_migration(build_migration_4()) migrator.register_migration(build_migration_5()) + migrator.register_migration(build_migration_6()) migrator.run_migrations() return db diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_6.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_6.py new file mode 100644 index 0000000000..e72878f726 --- /dev/null +++ 
b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_6.py @@ -0,0 +1,44 @@ +import sqlite3 +from logging import Logger + +from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration + +class Migration6Callback: + + def __call__(self, cursor: sqlite3.Cursor) -> None: + self._recreate_model_triggers(cursor) + + def _recreate_model_triggers(self, cursor: sqlite3.Cursor) -> None: + """ + Adds the timestamp trigger to the model_config table. + + This trigger was inadvertently dropped in earlier migration scripts. + """ + + cursor.execute( + """--sql + CREATE TRIGGER IF NOT EXISTS model_config_updated_at + AFTER UPDATE + ON model_config FOR EACH ROW + BEGIN + UPDATE model_config SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW') + WHERE id = old.id; + END; + """ + ) + +def build_migration_6() -> Migration: + """ + Build the migration from database version 5 to 6. + + This migration does the following: + - Adds the model_config_updated_at trigger if it does not exist + """ + migration_6 = Migration( + from_version=5, + to_version=6, + callback=Migration6Callback(), + ) + + return migration_6 diff --git a/invokeai/backend/install/install_helper.py b/invokeai/backend/install/install_helper.py index e54be527d9..8c03d2ccf8 100644 --- a/invokeai/backend/install/install_helper.py +++ b/invokeai/backend/install/install_helper.py @@ -98,11 +98,13 @@ class TqdmEventService(EventServiceBase): super().__init__() self._bars: Dict[str, tqdm] = {} self._last: Dict[str, int] = {} + self._logger = InvokeAILogger.get_logger(__name__) def dispatch(self, event_name: str, payload: Any) -> None: """Dispatch an event by appending it to self.events.""" + data = payload["data"] + source = data["source"] if payload["event"] == "model_install_downloading": - data = payload["data"] dest = data["local_path"] total_bytes = data["total_bytes"] bytes = data["bytes"] @@ -111,7 +113,12 @@ class TqdmEventService(EventServiceBase): self._last[dest] = 0 self._bars[dest].update(bytes - self._last[dest]) self._last[dest] = bytes - + elif payload["event"] == "model_install_completed": + self._logger.info(f"{source}: installed successfully.") + elif payload["event"] == "model_install_error": + self._logger.warning(f"{source}: installation failed with error {data['error']}") + elif payload["event"] == "model_install_cancelled": + self._logger.warning(f"{source}: installation cancelled") class InstallHelper(object): """Capture information stored jointly in INITIAL_MODELS.yaml and the installed models db.""" diff --git a/invokeai/backend/model_manager/__init__.py b/invokeai/backend/model_manager/__init__.py index 0f16852c93..f3c84cd01f 100644 --- a/invokeai/backend/model_manager/__init__.py +++ b/invokeai/backend/model_manager/__init__.py @@ -1,6 +1,7 @@ """Re-export frequently-used symbols from the Model Manager backend.""" from .config import ( + AnyModel, AnyModelConfig, BaseModelType, InvalidModelConfigException, @@ -14,12 +15,15 @@ from .config import ( ) from .probe import ModelProbe from .search import ModelSearch +from .load import LoadedModel __all__ = [ + "AnyModel", "AnyModelConfig", "BaseModelType", "ModelRepoVariant", "InvalidModelConfigException", + "LoadedModel", "ModelConfigFactory", "ModelFormat", "ModelProbe", diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 338669c873..796ccbacde 100644 --- a/invokeai/backend/model_manager/config.py +++ 
b/invokeai/backend/model_manager/config.py @@ -19,12 +19,15 @@ Typical usage: Validation errors will raise an InvalidModelConfigException error. """ +import time +import torch from enum import Enum from typing import Literal, Optional, Type, Union from pydantic import BaseModel, ConfigDict, Field, TypeAdapter +from diffusers import ModelMixin from typing_extensions import Annotated, Any, Dict - +from .onnx_runtime import IAIOnnxRuntimeModel class InvalidModelConfigException(Exception): """Exception for when config parser doesn't recognized this combination of model type and format.""" @@ -127,6 +130,7 @@ class ModelConfigBase(BaseModel): ) # if model is converted or otherwise modified, this will hold updated hash description: Optional[str] = Field(default=None) source: Optional[str] = Field(description="Model download source (URL or repo_id)", default=None) + last_modified: Optional[float] = Field(description="Timestamp for modification time", default_factory=time.time) model_config = ConfigDict( use_enum_values=False, @@ -280,6 +284,7 @@ AnyModelConfig = Union[ ] AnyModelConfigValidator = TypeAdapter(AnyModelConfig) +AnyModel = Union[ModelMixin, torch.nn.Module, IAIOnnxRuntimeModel] # IMPLEMENTATION NOTE: # The preferred alternative to the above is a discriminated Union as shown @@ -312,6 +317,7 @@ class ModelConfigFactory(object): model_data: Union[dict, AnyModelConfig], key: Optional[str] = None, dest_class: Optional[Type] = None, + timestamp: Optional[float] = None ) -> AnyModelConfig: """ Return the appropriate config object from raw dict values. @@ -330,4 +336,6 @@ class ModelConfigFactory(object): model = AnyModelConfigValidator.validate_python(model_data) if key: model.key = key + if timestamp: + model.last_modified = timestamp return model diff --git a/invokeai/backend/model_manager/convert_ckpt_to_diffusers.py b/invokeai/backend/model_manager/convert_ckpt_to_diffusers.py new file mode 100644 index 0000000000..9d6fc4841f --- /dev/null +++ b/invokeai/backend/model_manager/convert_ckpt_to_diffusers.py @@ -0,0 +1,1744 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# Adapted for use in InvokeAI by Lincoln Stein, July 2023 +# +""" Conversion script for the Stable Diffusion checkpoints.""" + +import re +from contextlib import nullcontext +from io import BytesIO +from pathlib import Path +from typing import Optional, Union + +import requests +import torch +from diffusers.models import AutoencoderKL, ControlNetModel, PriorTransformer, UNet2DConditionModel +from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel +from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer +from diffusers.schedulers import ( + DDIMScheduler, + DDPMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + HeunDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + UnCLIPScheduler, +) +from diffusers.utils import is_accelerate_available +from diffusers.utils.import_utils import BACKENDS_MAPPING +from picklescan.scanner import scan_file_path +from transformers import ( + AutoFeatureExtractor, + BertTokenizerFast, + CLIPImageProcessor, + CLIPTextConfig, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionConfig, + CLIPVisionModelWithProjection, +) + +from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.backend.util.logging import InvokeAILogger +from invokeai.backend.model_manager import BaseModelType, ModelVariantType + +try: + from omegaconf import OmegaConf + from omegaconf.dictconfig import DictConfig +except ImportError: + raise ImportError( + "OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`." + ) + +if is_accelerate_available(): + from accelerate import init_empty_weights + from accelerate.utils import set_module_tensor_to_device + +logger = InvokeAILogger.get_logger(__name__) +CONVERT_MODEL_ROOT = InvokeAIAppConfig.get_config().models_path / "core/convert" + + +def shave_segments(path, n_shave_prefix_segments=1): + """ + Removes segments. Positive values shave the first segments, negative shave the last segments. 
+ """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) + + +def renew_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") + + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") + + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("nin_shortcut", "conv_shortcut") + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + # new_item = new_item.replace('norm.weight', 'group_norm.weight') + # new_item = new_item.replace('norm.bias', 'group_norm.bias') + + # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') + # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') + + # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("norm.weight", "group_norm.weight") + new_item = new_item.replace("norm.bias", "group_norm.bias") + + new_item = new_item.replace("q.weight", "to_q.weight") + new_item = new_item.replace("q.bias", "to_q.bias") + + new_item = new_item.replace("k.weight", "to_k.weight") + new_item = new_item.replace("k.bias", "to_k.bias") + + new_item = new_item.replace("v.weight", "to_v.weight") + new_item = new_item.replace("v.bias", "to_v.bias") + + new_item = new_item.replace("proj_out.weight", "to_out.0.weight") + new_item = new_item.replace("proj_out.bias", "to_out.0.bias") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def assign_to_checkpoint( + paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None +): + """ + This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits + attention layers, and takes into account additional replacements that may arise. + + Assigns the weights to the new checkpoint. + """ + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + + # Splits the attention layers into three variables. 
+ if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 + + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) + + num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 + + old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + query, key, value = old_tensor.split(channels // num_heads, dim=1) + + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) + + for path in paths: + new_path = path["new"] + + # These have already been assigned + if attention_paths_to_split is not None and new_path in attention_paths_to_split: + continue + + # Global renaming happens here + new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") + new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") + new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) + + # proj_attn.weight has to be converted from conv 1D to linear + is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path) + shape = old_checkpoint[path["old"]].shape + if is_attn_weight and len(shape) == 3: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + elif is_attn_weight and len(shape) == 4: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0] + else: + checkpoint[new_path] = old_checkpoint[path["old"]] + + +def conv_attn_to_linear(checkpoint): + keys = list(checkpoint.keys()) + attn_keys = ["query.weight", "key.weight", "value.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in attn_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + elif "proj_attn.weight" in key: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0] + + +def create_unet_diffusers_config(original_config, image_size: int, controlnet=False): + """ + Creates a config for the diffusers based on the config of the LDM model. 
+ """ + if controlnet: + unet_params = original_config.model.params.control_stage_config.params + else: + if "unet_config" in original_config.model.params and original_config.model.params.unet_config is not None: + unet_params = original_config.model.params.unet_config.params + else: + unet_params = original_config.model.params.network_config.params + + vae_params = original_config.model.params.first_stage_config.params.ddconfig + + block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] + + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D" + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for _i in range(len(block_out_channels)): + block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D" + up_block_types.append(block_type) + resolution //= 2 + + if unet_params.transformer_depth is not None: + transformer_layers_per_block = ( + unet_params.transformer_depth + if isinstance(unet_params.transformer_depth, int) + else list(unet_params.transformer_depth) + ) + else: + transformer_layers_per_block = 1 + + vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1) + + head_dim = unet_params.num_heads if "num_heads" in unet_params else None + use_linear_projection = ( + unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False + ) + if use_linear_projection: + # stable diffusion 2-base-512 and 2-768 + if head_dim is None: + head_dim_mult = unet_params.model_channels // unet_params.num_head_channels + head_dim = [head_dim_mult * c for c in list(unet_params.channel_mult)] + + class_embed_type = None + addition_embed_type = None + addition_time_embed_dim = None + projection_class_embeddings_input_dim = None + context_dim = None + + if unet_params.context_dim is not None: + context_dim = ( + unet_params.context_dim if isinstance(unet_params.context_dim, int) else unet_params.context_dim[0] + ) + + if "num_classes" in unet_params: + if unet_params.num_classes == "sequential": + if context_dim in [2048, 1280]: + # SDXL + addition_embed_type = "text_time" + addition_time_embed_dim = 256 + else: + class_embed_type = "projection" + assert "adm_in_channels" in unet_params + projection_class_embeddings_input_dim = unet_params.adm_in_channels + else: + raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}") + + config = { + "sample_size": image_size // vae_scale_factor, + "in_channels": unet_params.in_channels, + "down_block_types": tuple(down_block_types), + "block_out_channels": tuple(block_out_channels), + "layers_per_block": unet_params.num_res_blocks, + "cross_attention_dim": context_dim, + "attention_head_dim": head_dim, + "use_linear_projection": use_linear_projection, + "class_embed_type": class_embed_type, + "addition_embed_type": addition_embed_type, + "addition_time_embed_dim": addition_time_embed_dim, + "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim, + "transformer_layers_per_block": transformer_layers_per_block, + } + + if controlnet: + config["conditioning_channels"] = unet_params.hint_channels + else: + config["out_channels"] = unet_params.out_channels + config["up_block_types"] = tuple(up_block_types) + + return config + + +def create_vae_diffusers_config(original_config, image_size: int): + """ + 
Creates a config for the diffusers based on the config of the LDM model. + """ + vae_params = original_config.model.params.first_stage_config.params.ddconfig + _ = original_config.model.params.first_stage_config.params.embed_dim + + block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult] + down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) + up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) + + config = { + "sample_size": image_size, + "in_channels": vae_params.in_channels, + "out_channels": vae_params.out_ch, + "down_block_types": tuple(down_block_types), + "up_block_types": tuple(up_block_types), + "block_out_channels": tuple(block_out_channels), + "latent_channels": vae_params.z_channels, + "layers_per_block": vae_params.num_res_blocks, + } + return config + + +def create_diffusers_schedular(original_config): + schedular = DDIMScheduler( + num_train_timesteps=original_config.model.params.timesteps, + beta_start=original_config.model.params.linear_start, + beta_end=original_config.model.params.linear_end, + beta_schedule="scaled_linear", + ) + return schedular + + +def create_ldm_bert_config(original_config): + bert_params = original_config.model.parms.cond_stage_config.params + config = LDMBertConfig( + d_model=bert_params.n_embed, + encoder_layers=bert_params.n_layer, + encoder_ffn_dim=bert_params.n_embed * 4, + ) + return config + + +def convert_ldm_unet_checkpoint( + checkpoint, config, path=None, extract_ema=False, controlnet=False, skip_extract_state_dict=False +): + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + + if skip_extract_state_dict: + unet_state_dict = checkpoint + else: + # extract state_dict for UNet + unet_state_dict = {} + keys = list(checkpoint.keys()) + + if controlnet: + unet_key = "control_model." + else: + unet_key = "model.diffusion_model." + + # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA + if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: + logger.warning(f"Checkpoint {path} has both EMA and non-EMA weights.") + logger.warning( + "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" + " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." + ) + for key in keys: + if key.startswith("model.diffusion_model"): + flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) + else: + if sum(k.startswith("model_ema") for k in keys) > 100: + logger.warning( + "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" + " weights (usually better for inference), please make sure to add the `--extract_ema` flag." + ) + + for key in keys: + if key.startswith(unet_key): + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) + + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] + + if config["class_embed_type"] is None: + # No parameters to port + ... 
+ elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection": + new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] + new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] + new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] + new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] + else: + raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}") + + if config["addition_embed_type"] == "text_time": + new_checkpoint["add_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] + new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] + new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] + new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] + + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + + if not controlnet: + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] + for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] + for layer_id in range(num_output_blocks) + } + + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + resnets = [ + key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + ] + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + + if f"input_blocks.{i}.0.op.weight" in unet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.bias" + ) + + paths = renew_resnet_paths(resnets) + meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} + assign_to_checkpoint( + paths, 
new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] + + resnet_0_paths = renew_resnet_paths(resnet_0) + assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) + + resnet_1_paths = renew_resnet_paths(resnet_1) + assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) + + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + if len(output_block_list) > 1: + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] + attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] + + resnet_0_paths = renew_resnet_paths(resnets) + paths = renew_resnet_paths(resnets) + + meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + output_block_list = {k: sorted(v) for k, v in output_block_list.items()} + if ["conv.bias", "conv.weight"] in output_block_list.values(): + index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.weight" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.bias" + ] + + # Clear attentions as they have been attributed above. 
+ if len(attentions) == 2: + attentions = [] + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = { + "old": f"output_blocks.{i}.1", + "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + else: + resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) + + new_checkpoint[new_path] = unet_state_dict[old_path] + + if controlnet: + # conditioning embedding + + orig_index = 0 + + new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.weight" + ) + new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.bias" + ) + + orig_index += 2 + + diffusers_index = 0 + + while diffusers_index < 6: + new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.weight" + ) + new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.bias" + ) + diffusers_index += 1 + orig_index += 2 + + new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.weight" + ) + new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.bias" + ) + + # down blocks + for i in range(num_input_blocks): + new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight") + new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias") + + # mid block + new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight") + new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias") + + return new_checkpoint + + +def convert_ldm_vae_checkpoint(checkpoint, config): + # extract state dict for VAE + vae_state_dict = {} + keys = list(checkpoint.keys()) + vae_key = "first_stage_model." 
if any(k.startswith("first_stage_model.") for k in keys) else "" + for key in keys: + if key.startswith(vae_key): + vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) + + new_checkpoint = {} + + new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] + new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] + new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] + new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] + + new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] + new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] + new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] + new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] + + new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] + new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] + new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] + new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) + down_blocks = { + layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) + } + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) + up_blocks = { + layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) + } + + for i in range(num_down_blocks): + resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.weight" + ) + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.bias" + ) + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": 
"mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + resnets = [ + key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key + ] + + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.weight" + ] + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.bias" + ] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + return new_checkpoint + + +def convert_ldm_bert_checkpoint(checkpoint, config): + def _copy_attn_layer(hf_attn_layer, pt_attn_layer): + hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight + hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight + hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight + + hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight + hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias + + def _copy_linear(hf_linear, pt_linear): + hf_linear.weight = pt_linear.weight + hf_linear.bias = pt_linear.bias + + def _copy_layer(hf_layer, pt_layer): + # copy layer norms + _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0]) + _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0]) + + # copy attn + _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1]) + + # copy MLP + pt_mlp = pt_layer[1][1] + _copy_linear(hf_layer.fc1, pt_mlp.net[0][0]) + _copy_linear(hf_layer.fc2, pt_mlp.net[2]) + + def _copy_layers(hf_layers, pt_layers): + for i, hf_layer in enumerate(hf_layers): + if i != 0: + i += i + pt_layer = pt_layers[i : i + 2] + _copy_layer(hf_layer, pt_layer) + + hf_model = LDMBertModel(config).eval() + + # copy embeds + hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight + hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight + + # copy layer norm + _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm) + + # copy hidden layers + _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers) + + _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits) + + return hf_model + + +def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder=None): + if text_encoder is None: + config = 
CLIPTextConfig.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14") + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + text_model = CLIPTextModel(config) + + keys = list(checkpoint.keys()) + + text_model_dict = {} + + remove_prefixes = ["cond_stage_model.transformer", "conditioner.embedders.0.transformer"] + + for key in keys: + for prefix in remove_prefixes: + if key.startswith(prefix): + text_model_dict[key[len(prefix + ".") :]] = checkpoint[key] + + if is_accelerate_available(): + for param_name, param in text_model_dict.items(): + set_module_tensor_to_device(text_model, param_name, "cpu", value=param) + else: + text_model.load_state_dict(text_model_dict) + + return text_model + + +textenc_conversion_lst = [ + ("positional_embedding", "text_model.embeddings.position_embedding.weight"), + ("token_embedding.weight", "text_model.embeddings.token_embedding.weight"), + ("ln_final.weight", "text_model.final_layer_norm.weight"), + ("ln_final.bias", "text_model.final_layer_norm.bias"), + ("text_projection", "text_projection.weight"), +] +textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst} + +textenc_transformer_conversion_lst = [ + # (stable-diffusion, HF Diffusers) + ("resblocks.", "text_model.encoder.layers."), + ("ln_1", "layer_norm1"), + ("ln_2", "layer_norm2"), + (".c_fc.", ".fc1."), + (".c_proj.", ".fc2."), + (".attn", ".self_attn"), + ("ln_final.", "transformer.text_model.final_layer_norm."), + ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"), + ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"), +] +protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst} +textenc_pattern = re.compile("|".join(protected.keys())) + + +def convert_paint_by_example_checkpoint(checkpoint): + config = CLIPVisionConfig.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14") + model = PaintByExampleImageEncoder(config) + + keys = list(checkpoint.keys()) + + text_model_dict = {} + + for key in keys: + if key.startswith("cond_stage_model.transformer"): + text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] + + # load clip vision + model.model.load_state_dict(text_model_dict) + + # load mapper + keys_mapper = { + k[len("cond_stage_model.mapper.res") :]: v + for k, v in checkpoint.items() + if k.startswith("cond_stage_model.mapper") + } + + MAPPING = { + "attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"], + "attn.c_proj": ["attn1.to_out.0"], + "ln_1": ["norm1"], + "ln_2": ["norm3"], + "mlp.c_fc": ["ff.net.0.proj"], + "mlp.c_proj": ["ff.net.2"], + } + + mapped_weights = {} + for key, value in keys_mapper.items(): + prefix = key[: len("blocks.i")] + suffix = key.split(prefix)[-1].split(".")[-1] + name = key.split(prefix)[-1].split(suffix)[0][1:-1] + mapped_names = MAPPING[name] + + num_splits = len(mapped_names) + for i, mapped_name in enumerate(mapped_names): + new_name = ".".join([prefix, mapped_name, suffix]) + shape = value.shape[0] // num_splits + mapped_weights[new_name] = value[i * shape : (i + 1) * shape] + + model.mapper.load_state_dict(mapped_weights) + + # load final layer norm + model.final_layer_norm.load_state_dict( + { + "bias": checkpoint["cond_stage_model.final_ln.bias"], + "weight": checkpoint["cond_stage_model.final_ln.weight"], + } + ) + + # load final proj + model.proj_out.load_state_dict( + { + "bias": checkpoint["proj_out.bias"], + "weight": checkpoint["proj_out.weight"], + } + ) + + # load 
uncond vector + model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"]) + return model + + +def convert_open_clip_checkpoint( + checkpoint, config_name, prefix="cond_stage_model.model.", has_projection=False, **config_kwargs +): + # text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder") + # text_model = CLIPTextModelWithProjection.from_pretrained( + # "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", projection_dim=1280 + # ) + config = CLIPTextConfig.from_pretrained(config_name, **config_kwargs) + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + text_model = CLIPTextModelWithProjection(config) if has_projection else CLIPTextModel(config) + + keys = list(checkpoint.keys()) + + keys_to_ignore = [] + if config_name == "stabilityai/stable-diffusion-2" and config.num_hidden_layers == 23: + # make sure to remove all keys > 22 + keys_to_ignore += [k for k in keys if k.startswith("cond_stage_model.model.transformer.resblocks.23")] + keys_to_ignore += ["cond_stage_model.model.text_projection"] + + text_model_dict = {} + + if prefix + "text_projection" in checkpoint: + d_model = int(checkpoint[prefix + "text_projection"].shape[0]) + else: + d_model = 1024 + + text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids") + + for key in keys: + if key in keys_to_ignore: + continue + if key[len(prefix) :] in textenc_conversion_map: + if key.endswith("text_projection"): + value = checkpoint[key].T.contiguous() + else: + value = checkpoint[key] + + text_model_dict[textenc_conversion_map[key[len(prefix) :]]] = value + + if key.startswith(prefix + "transformer."): + new_key = key[len(prefix + "transformer.") :] + if new_key.endswith(".in_proj_weight"): + new_key = new_key[: -len(".in_proj_weight")] + new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) + text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :] + text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :] + text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :] + elif new_key.endswith(".in_proj_bias"): + new_key = new_key[: -len(".in_proj_bias")] + new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) + text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model] + text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2] + text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :] + else: + new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) + + text_model_dict[new_key] = checkpoint[key] + + if is_accelerate_available(): + for param_name, param in text_model_dict.items(): + set_module_tensor_to_device(text_model, param_name, "cpu", value=param) + else: + text_model.load_state_dict(text_model_dict) + + return text_model + + +def stable_unclip_image_encoder(original_config): + """ + Returns the image processor and clip image encoder for the img2img unclip pipeline. + + We currently know of two types of stable unclip models which separately use the clip and the openclip image + encoders. 
+ """ + + image_embedder_config = original_config.model.params.embedder_config + + sd_clip_image_embedder_class = image_embedder_config.target + sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1] + + if sd_clip_image_embedder_class == "ClipImageEmbedder": + clip_model_name = image_embedder_config.params.model + + if clip_model_name == "ViT-L/14": + feature_extractor = CLIPImageProcessor() + image_encoder = CLIPVisionModelWithProjection.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14") + else: + raise NotImplementedError(f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}") + + elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder": + feature_extractor = CLIPImageProcessor() + # InvokeAI doesn't use CLIPVisionModelWithProjection so it isn't in the core - if this code is hit a download will occur + image_encoder = CLIPVisionModelWithProjection.from_pretrained( + CONVERT_MODEL_ROOT / "CLIP-ViT-H-14-laion2B-s32B-b79K" + ) + else: + raise NotImplementedError( + f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}" + ) + + return feature_extractor, image_encoder + + +def stable_unclip_image_noising_components( + original_config, clip_stats_path: Optional[str] = None, device: Optional[str] = None +): + """ + Returns the noising components for the img2img and txt2img unclip pipelines. + + Converts the stability noise augmentor into + 1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats + 2. a `DDPMScheduler` for holding the noise schedule + + If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided. + """ + noise_aug_config = original_config.model.params.noise_aug_config + noise_aug_class = noise_aug_config.target + noise_aug_class = noise_aug_class.split(".")[-1] + + if noise_aug_class == "CLIPEmbeddingNoiseAugmentation": + noise_aug_config = noise_aug_config.params + embedding_dim = noise_aug_config.timestep_dim + max_noise_level = noise_aug_config.noise_schedule_config.timesteps + beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule + + image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim) + image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule) + + if "clip_stats_path" in noise_aug_config: + if clip_stats_path is None: + raise ValueError("This stable unclip config requires a `clip_stats_path`") + + clip_mean, clip_std = torch.load(clip_stats_path, map_location=device) + clip_mean = clip_mean[None, :] + clip_std = clip_std[None, :] + + clip_stats_state_dict = { + "mean": clip_mean, + "std": clip_std, + } + + image_normalizer.load_state_dict(clip_stats_state_dict) + else: + raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}") + + return image_normalizer, image_noising_scheduler + + +def convert_controlnet_checkpoint( + checkpoint, + original_config, + checkpoint_path, + image_size, + upcast_attention, + extract_ema, + use_linear_projection=None, + cross_attention_dim=None, + precision: Optional[torch.dtype] = None, +): + ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True) + ctrlnet_config["upcast_attention"] = upcast_attention + + ctrlnet_config.pop("sample_size") + original_config = ctrlnet_config.copy() + + ctrlnet_config.pop("addition_embed_type") + ctrlnet_config.pop("addition_time_embed_dim") + ctrlnet_config.pop("transformer_layers_per_block") + + if 
use_linear_projection is not None: + ctrlnet_config["use_linear_projection"] = use_linear_projection + + if cross_attention_dim is not None: + ctrlnet_config["cross_attention_dim"] = cross_attention_dim + + controlnet = ControlNetModel(**ctrlnet_config) + + # Some controlnet ckpt files are distributed independently from the rest of the + # model components i.e. https://huggingface.co/thibaud/controlnet-sd21/ + if "time_embed.0.weight" in checkpoint: + skip_extract_state_dict = True + else: + skip_extract_state_dict = False + + converted_ctrl_checkpoint = convert_ldm_unet_checkpoint( + checkpoint, + original_config, + path=checkpoint_path, + extract_ema=extract_ema, + controlnet=True, + skip_extract_state_dict=skip_extract_state_dict, + ) + + controlnet.load_state_dict(converted_ctrl_checkpoint) + + return controlnet.to(precision) + + +def download_from_original_stable_diffusion_ckpt( + checkpoint_path: str, + model_version: BaseModelType, + model_variant: ModelVariantType, + original_config_file: str = None, + image_size: Optional[int] = None, + prediction_type: str = None, + model_type: str = None, + extract_ema: bool = False, + precision: Optional[torch.dtype] = None, + scheduler_type: str = "pndm", + num_in_channels: Optional[int] = None, + upcast_attention: Optional[bool] = None, + device: str = None, + from_safetensors: bool = False, + stable_unclip: Optional[str] = None, + stable_unclip_prior: Optional[str] = None, + clip_stats_path: Optional[str] = None, + controlnet: Optional[bool] = None, + load_safety_checker: bool = True, + pipeline_class: DiffusionPipeline = None, + local_files_only=False, + vae_path=None, + text_encoder=None, + tokenizer=None, + scan_needed: bool = True, +) -> DiffusionPipeline: + """ + Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml` + config file. + + Although many of the arguments can be automatically inferred, some of these rely on brittle checks against the + global step count, which will likely fail for models that have undergone further fine-tuning. Therefore, it is + recommended that you override the default values and/or supply an `original_config_file` wherever possible. + + Args: + checkpoint_path (`str`): Path to `.ckpt` file. + original_config_file (`str`): + Path to `.yaml` config file corresponding to the original architecture. If `None`, will be automatically + inferred by looking for a key that only exists in SD2.0 models. + image_size (`int`, *optional*, defaults to 512): + The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2 + Base. Use 768 for Stable Diffusion v2. + prediction_type (`str`, *optional*): + The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion v1.X and Stable + Diffusion v2 Base. Use `'v_prediction'` for Stable Diffusion v2. + num_in_channels (`int`, *optional*, defaults to None): + The number of input channels. If `None`, it will be automatically inferred. + scheduler_type (`str`, *optional*, defaults to 'pndm'): + Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm", + "ddim"]`. + model_type (`str`, *optional*, defaults to `None`): + The pipeline type. `None` to automatically infer, or one of `["FrozenOpenCLIPEmbedder", + "FrozenCLIPEmbedder", "PaintByExample"]`. + is_img2img (`bool`, *optional*, defaults to `False`): + Whether the model should be loaded as an img2img pipeline. 
+ extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for + checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to + `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for + inference. Non-EMA weights are usually better to continue fine-tuning. + upcast_attention (`bool`, *optional*, defaults to `None`): + Whether the attention computation should always be upcasted. This is necessary when running stable + diffusion 2.1. + device (`str`, *optional*, defaults to `None`): + The device to use. Pass `None` to determine automatically. + from_safetensors (`str`, *optional*, defaults to `False`): + If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch. + load_safety_checker (`bool`, *optional*, defaults to `True`): + Whether to load the safety checker or not. Defaults to `True`. + pipeline_class (`str`, *optional*, defaults to `None`): + The pipeline class to use. Pass `None` to determine automatically. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + text_encoder (`CLIPTextModel`, *optional*, defaults to `None`): + An instance of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel) + to use, specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) + variant. If this parameter is `None`, the function will load a new instance of [CLIP] by itself, if needed. + tokenizer (`CLIPTokenizer`, *optional*, defaults to `None`): + An instance of + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer) + to use. If this parameter is `None`, the function will load a new instance of [CLIPTokenizer] by itself, if + needed. + precision (`torch.dtype`, *optional*, defauts to `None`): + If not provided the precision will be set to the precision of the original file. + return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file. + """ + + # import pipelines here to avoid circular import error when using from_single_file method + from diffusers import ( + LDMTextToImagePipeline, + PaintByExamplePipeline, + StableDiffusionControlNetPipeline, + StableDiffusionInpaintPipeline, + StableDiffusionPipeline, + StableDiffusionXLImg2ImgPipeline, + StableDiffusionXLPipeline, + StableUnCLIPImg2ImgPipeline, + StableUnCLIPPipeline, + ) + + if pipeline_class is None: + pipeline_class = StableDiffusionPipeline if not controlnet else StableDiffusionControlNetPipeline + + if prediction_type == "v-prediction": + prediction_type = "v_prediction" + + if from_safetensors: + from safetensors.torch import load_file as safe_load + + checkpoint = safe_load(checkpoint_path, device="cpu") + else: + if scan_needed: + # scan model + scan_result = scan_file_path(checkpoint_path) + if scan_result.infected_files != 0: + raise Exception("The model {checkpoint_path} is potentially infected by malware. 
Aborting import.") + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + checkpoint = torch.load(checkpoint_path, map_location=device) + else: + checkpoint = torch.load(checkpoint_path, map_location=device) + + # Sometimes models don't have the global_step item + if "global_step" in checkpoint: + global_step = checkpoint["global_step"] + else: + logger.debug("global_step key not found in model") + global_step = None + + # NOTE: this while loop isn't great but this controlnet checkpoint has one additional + # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21 + while "state_dict" in checkpoint: + checkpoint = checkpoint["state_dict"] + + logger.debug(f"model_type = {model_type}; original_config_file = {original_config_file}") + + precision_probing_key = "model.diffusion_model.input_blocks.0.0.bias" + logger.debug(f"original checkpoint precision == {checkpoint[precision_probing_key].dtype}") + precision = precision or checkpoint[precision_probing_key].dtype + + if original_config_file is None: + key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" + key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias" + key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias" + + # model_type = "v1" + config_url = ( + "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" + ) + + if key_name_v2_1 in checkpoint and checkpoint[key_name_v2_1].shape[-1] == 1024: + # model_type = "v2" + config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml" + + if global_step == 110000: + # v2.1 needs to upcast attention + upcast_attention = True + elif key_name_sd_xl_base in checkpoint: + # only base xl has two text embedders + config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml" + elif key_name_sd_xl_refiner in checkpoint: + # only refiner xl has embedder and one text embedders + config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml" + + original_config_file = BytesIO(requests.get(config_url).content) + + original_config = OmegaConf.load(original_config_file) + if original_config["model"]["params"].get("use_ema") is not None: + extract_ema = original_config["model"]["params"]["use_ema"] + + if ( + model_version in [BaseModelType.StableDiffusion2, BaseModelType.StableDiffusion1] + and original_config["model"]["params"].get("parameterization") == "v" + ): + prediction_type = "v_prediction" + upcast_attention = True + image_size = 768 if model_version == BaseModelType.StableDiffusion2 else 512 + else: + prediction_type = "epsilon" + upcast_attention = False + image_size = 512 + + # Convert the text model. 
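For reference, the original-config inference above reduces to probing for state-dict keys that exist only in particular checkpoint families. A hedged sketch of that heuristic; the helper name is illustrative and not defined in this diff:

def guess_original_config_url(checkpoint: dict) -> str:
    """Pick a config URL by probing checkpoint keys (mirrors the logic above)."""
    key_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
    key_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias"
    key_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias"

    if key_v2_1 in checkpoint and checkpoint[key_v2_1].shape[-1] == 1024:
        # SD 2.x checkpoints carry 1024-dim cross-attention keys
        return "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
    if key_sd_xl_base in checkpoint:
        # only the SDXL base model carries a second text embedder
        return "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
    if key_sd_xl_refiner in checkpoint:
        return "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml"
    # default: SD 1.x
    return "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"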
+ if ( + model_type is None + and "cond_stage_config" in original_config.model.params + and original_config.model.params.cond_stage_config is not None + ): + model_type = original_config.model.params.cond_stage_config.target.split(".")[-1] + logger.debug(f"no `model_type` given, `model_type` inferred as: {model_type}") + elif model_type is None and original_config.model.params.network_config is not None: + if original_config.model.params.network_config.params.context_dim == 2048: + model_type = "SDXL" + else: + model_type = "SDXL-Refiner" + if image_size is None: + image_size = 1024 + + if num_in_channels is None and pipeline_class == StableDiffusionInpaintPipeline: + num_in_channels = 9 + elif num_in_channels is None: + num_in_channels = 4 + + if "unet_config" in original_config.model.params: + original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels + + if ( + "parameterization" in original_config["model"]["params"] + and original_config["model"]["params"]["parameterization"] == "v" + ): + if prediction_type is None: + # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"` + # as it relies on a brittle global step parameter here + prediction_type = "epsilon" if global_step == 875000 else "v_prediction" + if image_size is None: + # NOTE: For stable diffusion 2 base one has to pass `image_size==512` + # as it relies on a brittle global step parameter here + image_size = 512 if global_step == 875000 else 768 + else: + if prediction_type is None: + prediction_type = "epsilon" + if image_size is None: + image_size = 512 + + if controlnet is None and "control_stage_config" in original_config.model.params: + controlnet = convert_controlnet_checkpoint( + checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema + ) + + num_train_timesteps = getattr(original_config.model.params, "timesteps", None) or 1000 + + if model_type in ["SDXL", "SDXL-Refiner"]: + scheduler_dict = { + "beta_schedule": "scaled_linear", + "beta_start": 0.00085, + "beta_end": 0.012, + "interpolation_type": "linear", + "num_train_timesteps": num_train_timesteps, + "prediction_type": "epsilon", + "sample_max_value": 1.0, + "set_alpha_to_one": False, + "skip_prk_steps": True, + "steps_offset": 1, + "timestep_spacing": "leading", + } + scheduler = EulerDiscreteScheduler.from_config(scheduler_dict) + scheduler_type = "euler" + else: + beta_start = getattr(original_config.model.params, "linear_start", None) or 0.02 + beta_end = getattr(original_config.model.params, "linear_end", None) or 0.085 + scheduler = DDIMScheduler( + beta_end=beta_end, + beta_schedule="scaled_linear", + beta_start=beta_start, + num_train_timesteps=num_train_timesteps, + steps_offset=1, + clip_sample=False, + set_alpha_to_one=False, + prediction_type=prediction_type, + ) + # make sure scheduler works correctly with DDIM + scheduler.register_to_config(clip_sample=False) + + if scheduler_type == "pndm": + config = dict(scheduler.config) + config["skip_prk_steps"] = True + scheduler = PNDMScheduler.from_config(config) + elif scheduler_type == "lms": + scheduler = LMSDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "heun": + scheduler = HeunDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "euler": + scheduler = EulerDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "euler-ancestral": + scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "dpm": + 
scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config) + elif scheduler_type == "ddim": + scheduler = scheduler + else: + raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!") + + # Convert the UNet2DConditionModel model. + unet_config = create_unet_diffusers_config(original_config, image_size=image_size) + unet_config["upcast_attention"] = upcast_attention + converted_unet_checkpoint = convert_ldm_unet_checkpoint( + checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema + ) + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + unet = UNet2DConditionModel(**unet_config) + + if is_accelerate_available(): + for param_name, param in converted_unet_checkpoint.items(): + set_module_tensor_to_device(unet, param_name, "cpu", value=param) + else: + unet.load_state_dict(converted_unet_checkpoint) + + # Convert the VAE model. + if vae_path is None: + vae_config = create_vae_diffusers_config(original_config, image_size=image_size) + converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) + + if ( + "model" in original_config + and "params" in original_config.model + and "scale_factor" in original_config.model.params + ): + vae_scaling_factor = original_config.model.params.scale_factor + else: + vae_scaling_factor = 0.18215 # default SD scaling factor + + vae_config["scaling_factor"] = vae_scaling_factor + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + vae = AutoencoderKL(**vae_config) + + if is_accelerate_available(): + for param_name, param in converted_vae_checkpoint.items(): + set_module_tensor_to_device(vae, param_name, "cpu", value=param) + else: + vae.load_state_dict(converted_vae_checkpoint) + else: + vae = AutoencoderKL.from_pretrained(vae_path) + + if model_type == "FrozenOpenCLIPEmbedder": + config_name = "stabilityai/stable-diffusion-2" + config_kwargs = {"subfolder": "text_encoder"} + + text_model = convert_open_clip_checkpoint(checkpoint, config_name, **config_kwargs) + tokenizer = CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "stable-diffusion-2-clip", subfolder="tokenizer") + + if stable_unclip is None: + if controlnet: + pipe = pipeline_class( + vae=vae.to(precision), + text_encoder=text_model.to(precision), + tokenizer=tokenizer, + unet=unet.to(precision), + scheduler=scheduler, + controlnet=controlnet, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + else: + pipe = pipeline_class( + vae=vae.to(precision), + text_encoder=text_model.to(precision), + tokenizer=tokenizer, + unet=unet.to(precision), + scheduler=scheduler, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + else: + image_normalizer, image_noising_scheduler = stable_unclip_image_noising_components( + original_config, clip_stats_path=clip_stats_path, device=device + ) + + if stable_unclip == "img2img": + feature_extractor, image_encoder = stable_unclip_image_encoder(original_config) + + pipe = StableUnCLIPImg2ImgPipeline( + # image encoding components + feature_extractor=feature_extractor, + image_encoder=image_encoder, + # image noising components + image_normalizer=image_normalizer, + image_noising_scheduler=image_noising_scheduler, + # regular denoising components + tokenizer=tokenizer, + text_encoder=text_model.to(precision), + unet=unet.to(precision), + scheduler=scheduler, + # vae + vae=vae, + ) + elif stable_unclip == "txt2img": + if stable_unclip_prior is None or stable_unclip_prior == 
"karlo": + karlo_model = "kakaobrain/karlo-v1-alpha" + prior = PriorTransformer.from_pretrained(karlo_model, subfolder="prior") + + prior_tokenizer = CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14") + prior_text_model = CLIPTextModelWithProjection.from_pretrained( + CONVERT_MODEL_ROOT / "clip-vit-large-patch14" + ) + + prior_scheduler = UnCLIPScheduler.from_pretrained(karlo_model, subfolder="prior_scheduler") + prior_scheduler = DDPMScheduler.from_config(prior_scheduler.config) + else: + raise NotImplementedError(f"unknown prior for stable unclip model: {stable_unclip_prior}") + + pipe = StableUnCLIPPipeline( + # prior components + prior_tokenizer=prior_tokenizer, + prior_text_encoder=prior_text_model, + prior=prior, + prior_scheduler=prior_scheduler, + # image noising components + image_normalizer=image_normalizer, + image_noising_scheduler=image_noising_scheduler, + # regular denoising components + tokenizer=tokenizer, + text_encoder=text_model, + unet=unet, + scheduler=scheduler, + # vae + vae=vae, + ) + else: + raise NotImplementedError(f"unknown `stable_unclip` type: {stable_unclip}") + elif model_type == "PaintByExample": + vision_model = convert_paint_by_example_checkpoint(checkpoint) + tokenizer = CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14") + feature_extractor = AutoFeatureExtractor.from_pretrained(CONVERT_MODEL_ROOT / "stable-diffusion-safety-checker") + pipe = PaintByExamplePipeline( + vae=vae, + image_encoder=vision_model, + unet=unet, + scheduler=scheduler, + safety_checker=None, + feature_extractor=feature_extractor, + ) + elif model_type == "FrozenCLIPEmbedder": + text_model = convert_ldm_clip_checkpoint( + checkpoint, local_files_only=local_files_only, text_encoder=text_encoder + ) + tokenizer = ( + CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14") + if tokenizer is None + else tokenizer + ) + + if load_safety_checker: + safety_checker = StableDiffusionSafetyChecker.from_pretrained( + CONVERT_MODEL_ROOT / "stable-diffusion-safety-checker" + ) + feature_extractor = AutoFeatureExtractor.from_pretrained( + CONVERT_MODEL_ROOT / "stable-diffusion-safety-checker" + ) + else: + safety_checker = None + feature_extractor = None + + if controlnet: + pipe = pipeline_class( + vae=vae.to(precision), + text_encoder=text_model.to(precision), + tokenizer=tokenizer, + unet=unet.to(precision), + controlnet=controlnet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + else: + pipe = pipeline_class( + vae=vae.to(precision), + text_encoder=text_model.to(precision), + tokenizer=tokenizer, + unet=unet.to(precision), + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + elif model_type in ["SDXL", "SDXL-Refiner"]: + if model_type == "SDXL": + tokenizer = CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14") + text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only) + + tokenizer_name = CONVERT_MODEL_ROOT / "CLIP-ViT-bigG-14-laion2B-39B-b160k" + tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_name, pad_token="!") + + config_name = tokenizer_name + config_kwargs = {"projection_dim": 1280} + text_encoder_2 = convert_open_clip_checkpoint( + checkpoint, config_name, prefix="conditioner.embedders.1.model.", has_projection=True, **config_kwargs + ) + + pipe = StableDiffusionXLPipeline( + vae=vae.to(precision), + text_encoder=text_encoder.to(precision), + 
tokenizer=tokenizer, + text_encoder_2=text_encoder_2.to(precision), + tokenizer_2=tokenizer_2, + unet=unet.to(precision), + scheduler=scheduler, + force_zeros_for_empty_prompt=True, + ) + else: + tokenizer = None + text_encoder = None + tokenizer_name = CONVERT_MODEL_ROOT / "CLIP-ViT-bigG-14-laion2B-39B-b160k" + tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_name, pad_token="!") + + config_name = tokenizer_name + config_kwargs = {"projection_dim": 1280} + text_encoder_2 = convert_open_clip_checkpoint( + checkpoint, config_name, prefix="conditioner.embedders.0.model.", has_projection=True, **config_kwargs + ) + + pipe = StableDiffusionXLImg2ImgPipeline( + vae=vae.to(precision), + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + unet=unet.to(precision), + scheduler=scheduler, + requires_aesthetics_score=True, + force_zeros_for_empty_prompt=False, + ) + else: + text_config = create_ldm_bert_config(original_config) + text_model = convert_ldm_bert_checkpoint(checkpoint, text_config) + tokenizer = BertTokenizerFast.from_pretrained(CONVERT_MODEL_ROOT / "bert-base-uncased") + pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler) + + return pipe + + +def download_controlnet_from_original_ckpt( + checkpoint_path: str, + original_config_file: str, + image_size: int = 512, + extract_ema: bool = False, + precision: Optional[torch.dtype] = None, + num_in_channels: Optional[int] = None, + upcast_attention: Optional[bool] = None, + device: str = None, + from_safetensors: bool = False, + use_linear_projection: Optional[bool] = None, + cross_attention_dim: Optional[bool] = None, + scan_needed: bool = False, +) -> DiffusionPipeline: + + from omegaconf import OmegaConf + + if from_safetensors: + from safetensors import safe_open + + checkpoint = {} + with safe_open(checkpoint_path, framework="pt", device="cpu") as f: + for key in f.keys(): + checkpoint[key] = f.get_tensor(key) + else: + if scan_needed: + # scan model + scan_result = scan_file_path(checkpoint_path) + if scan_result.infected_files != 0: + raise Exception("The model {checkpoint_path} is potentially infected by malware. 
Aborting import.") + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + checkpoint = torch.load(checkpoint_path, map_location=device) + else: + checkpoint = torch.load(checkpoint_path, map_location=device) + + # NOTE: this while loop isn't great but this controlnet checkpoint has one additional + # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21 + while "state_dict" in checkpoint: + checkpoint = checkpoint["state_dict"] + + # use original precision + precision_probing_key = "input_blocks.0.0.bias" + ckpt_precision = checkpoint[precision_probing_key].dtype + logger.debug(f"original controlnet precision = {ckpt_precision}") + precision = precision or ckpt_precision + + original_config = OmegaConf.load(original_config_file) + + if num_in_channels is not None: + original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels + + if "control_stage_config" not in original_config.model.params: + raise ValueError("`control_stage_config` not present in original config") + + controlnet = convert_controlnet_checkpoint( + checkpoint, + original_config, + checkpoint_path, + image_size, + upcast_attention, + extract_ema, + use_linear_projection=use_linear_projection, + cross_attention_dim=cross_attention_dim, + ) + + return controlnet.to(precision) + + +def convert_ldm_vae_to_diffusers(checkpoint, vae_config: DictConfig, image_size: int) -> AutoencoderKL: + vae_config = create_vae_diffusers_config(vae_config, image_size=image_size) + + converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) + + vae = AutoencoderKL(**vae_config) + vae.load_state_dict(converted_vae_checkpoint) + return vae + + +def convert_ckpt_to_diffusers( + checkpoint_path: Union[str, Path], + dump_path: Union[str, Path], + use_safetensors: bool = True, + **kwargs, +): + """ + Takes all the arguments of download_from_original_stable_diffusion_ckpt(), + and in addition a path-like object indicating the location of the desired diffusers + model to be written. + """ + pipe = download_from_original_stable_diffusion_ckpt(checkpoint_path, **kwargs) + + # TO DO: save correct repo variant + pipe.save_pretrained( + dump_path, + safe_serialization=use_safetensors, + ) + + +def convert_controlnet_to_diffusers( + checkpoint_path: Union[str, Path], + dump_path: Union[str, Path], + **kwargs, +): + """ + Takes all the arguments of download_controlnet_from_original_ckpt(), + and in addition a path-like object indicating the location of the desired diffusers + model to be written. + """ + pipe = download_controlnet_from_original_ckpt(checkpoint_path, **kwargs) + + # TO DO: save correct repo variant + pipe.save_pretrained(dump_path, safe_serialization=True) diff --git a/invokeai/backend/model_manager/load/__init__.py b/invokeai/backend/model_manager/load/__init__.py index e69de29bb2..357677bb7f 100644 --- a/invokeai/backend/model_manager/load/__init__.py +++ b/invokeai/backend/model_manager/load/__init__.py @@ -0,0 +1,35 @@ +# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development Team +""" +Init file for the model loader. 
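For orientation, a hedged example of driving the convert_ckpt_to_diffusers wrapper defined above. The paths are placeholders, and the import locations and enum member are assumptions based on the rest of this diff rather than verified API:

from invokeai.backend.model_manager import BaseModelType, ModelVariantType
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ckpt_to_diffusers

# Extra keyword arguments are forwarded to download_from_original_stable_diffusion_ckpt().
convert_ckpt_to_diffusers(
    "models/checkpoints/sd-v1-5.safetensors",   # checkpoint_path (placeholder)
    "models/.cache/sd-v1-5",                    # dump_path (placeholder)
    use_safetensors=True,
    model_version=BaseModelType.StableDiffusion1,
    model_variant=ModelVariantType.Normal,      # assumed enum member
    from_safetensors=True,
    load_safety_checker=False,
)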
+""" +from importlib import import_module +from pathlib import Path +from typing import Optional + +from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.backend.util.logging import InvokeAILogger +from .load_base import AnyModelLoader, LoadedModel +from .model_cache.model_cache_default import ModelCache +from .convert_cache.convert_cache_default import ModelConvertCache + +# This registers the subclasses that implement loaders of specific model types +loaders = [x.stem for x in Path(Path(__file__).parent,'model_loaders').glob('*.py') if x.stem != '__init__'] +for module in loaders: + print(f'module={module}') + import_module(f"{__package__}.model_loaders.{module}") + +__all__ = ["AnyModelLoader", "LoadedModel"] + + +def get_standalone_loader(app_config: Optional[InvokeAIAppConfig]) -> AnyModelLoader: + app_config = app_config or InvokeAIAppConfig.get_config() + logger = InvokeAILogger.get_logger(config=app_config) + return AnyModelLoader(app_config=app_config, + logger=logger, + ram_cache=ModelCache(logger=logger, + max_cache_size=app_config.ram_cache_size, + max_vram_cache_size=app_config.vram_cache_size + ), + convert_cache=ModelConvertCache(app_config.models_convert_cache_path) + ) + diff --git a/invokeai/backend/model_manager/load/convert_cache/__init__.py b/invokeai/backend/model_manager/load/convert_cache/__init__.py new file mode 100644 index 0000000000..eb3149be32 --- /dev/null +++ b/invokeai/backend/model_manager/load/convert_cache/__init__.py @@ -0,0 +1,4 @@ +from .convert_cache_base import ModelConvertCacheBase +from .convert_cache_default import ModelConvertCache + +__all__ = ['ModelConvertCacheBase', 'ModelConvertCache'] diff --git a/invokeai/backend/model_manager/load/convert_cache/convert_cache_base.py b/invokeai/backend/model_manager/load/convert_cache/convert_cache_base.py new file mode 100644 index 0000000000..25263f96aa --- /dev/null +++ b/invokeai/backend/model_manager/load/convert_cache/convert_cache_base.py @@ -0,0 +1,28 @@ +""" +Disk-based converted model cache. +""" +from abc import ABC, abstractmethod +from pathlib import Path + +class ModelConvertCacheBase(ABC): + + @property + @abstractmethod + def max_size(self) -> float: + """Return the maximum size of this cache directory.""" + pass + + @abstractmethod + def make_room(self, size: float) -> None: + """ + Make sufficient room in the cache directory for a model of max_size. + + :param size: Size required (GB) + """ + pass + + @abstractmethod + def cache_path(self, key: str) -> Path: + """Return the path for a model with the indicated key.""" + pass + diff --git a/invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py b/invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py new file mode 100644 index 0000000000..f799510ec5 --- /dev/null +++ b/invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py @@ -0,0 +1,64 @@ +""" +Placeholder for convert cache implementation. 
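The abstract convert-cache interface above is deliberately small: callers ask cache_path(key) for the destination directory, call make_room() before writing a new conversion, and simply reuse the directory on later loads. A rough usage sketch under assumed names; the helper and the convert callable are illustrative, not part of this diff:

from pathlib import Path
from typing import Callable

def converted_model_path(
    cache: "ModelConvertCacheBase",
    key: str,
    required_size: float,
    convert: Callable[[Path], None],
) -> Path:
    """Return the cached diffusers conversion for `key`, creating it if absent."""
    target = cache.cache_path(key)
    if target.exists():
        return target                 # reuse an earlier conversion
    cache.make_room(required_size)    # trims least-recently-used conversions if the cap is exceeded
    convert(target)                   # caller-supplied converter writes into the cache directory
    return target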
+""" + +from pathlib import Path +import shutil +from invokeai.backend.util.logging import InvokeAILogger +from invokeai.backend.util import GIG, directory_size +from .convert_cache_base import ModelConvertCacheBase + +class ModelConvertCache(ModelConvertCacheBase): + + def __init__(self, cache_path: Path, max_size: float=10.0): + """Initialize the convert cache with the base directory and a limit on its maximum size (in GBs).""" + if not cache_path.exists(): + cache_path.mkdir(parents=True) + self._cache_path = cache_path + self._max_size = max_size + + @property + def max_size(self) -> float: + """Return the maximum size of this cache directory (GB).""" + return self._max_size + + def cache_path(self, key: str) -> Path: + """Return the path for a model with the indicated key.""" + return self._cache_path / key + + def make_room(self, size: float) -> None: + """ + Make sufficient room in the cache directory for a model of max_size. + + :param size: Size required (GB) + """ + size_needed = directory_size(self._cache_path) + size + max_size = int(self.max_size) * GIG + logger = InvokeAILogger.get_logger() + + if size_needed <= max_size: + return + + logger.debug( + f"Convert cache has gotten too large {(size_needed / GIG):4.2f} > {(max_size / GIG):4.2f}G.. Trimming." + ) + + # For this to work, we make the assumption that the directory contains + # a 'model_index.json', 'unet/config.json' file, or a 'config.json' file at top level. + # This should be true for any diffusers model. + def by_atime(path: Path) -> float: + for config in ["model_index.json", "unet/config.json", "config.json"]: + sentinel = path / config + if sentinel.exists(): + return sentinel.stat().st_atime + return 0.0 + + # sort by last access time - least accessed files will be at the end + lru_models = sorted(self._cache_path.iterdir(), key=by_atime, reverse=True) + logger.debug(f"cached models in descending atime order: {lru_models}") + while size_needed > max_size and len(lru_models) > 0: + next_victim = lru_models.pop() + victim_size = directory_size(next_victim) + logger.debug(f"Removing cached converted model {next_victim} to free {victim_size / GIG} GB") + shutil.rmtree(next_victim) + size_needed -= victim_size diff --git a/invokeai/backend/model_manager/load/load_base.py b/invokeai/backend/model_manager/load/load_base.py index 7cb7222b71..3ade83160a 100644 --- a/invokeai/backend/model_manager/load/load_base.py +++ b/invokeai/backend/model_manager/load/load_base.py @@ -16,39 +16,11 @@ from logging import Logger from pathlib import Path from typing import Any, Callable, Dict, Optional, Type, Union -import torch -from diffusers import DiffusionPipeline -from injector import inject - from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.model_records import ModelRecordServiceBase -from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType -from invokeai.backend.model_manager.convert_cache import ModelConvertCacheBase -from invokeai.backend.model_manager.onnx_runtime import IAIOnnxRuntimeModel -from invokeai.backend.model_manager.ram_cache import ModelCacheBase - -AnyModel = Union[DiffusionPipeline, torch.nn.Module, IAIOnnxRuntimeModel] - - -class ModelLockerBase(ABC): - """Base class for the model locker used by the loader.""" - - @abstractmethod - def lock(self) -> None: - """Lock the contained model and move it into VRAM.""" - pass - - @abstractmethod - def unlock(self) -> None: - """Unlock the contained model, and remove it from 
VRAM.""" - pass - - @property - @abstractmethod - def model(self) -> AnyModel: - """Return the model.""" - pass - +from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType +from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase +from invokeai.backend.model_manager.load.model_cache.model_locker import ModelLockerBase +from invokeai.backend.model_manager.load.convert_cache.convert_cache_base import ModelConvertCacheBase @dataclass class LoadedModel: @@ -69,7 +41,7 @@ class LoadedModel: @property def model(self) -> AnyModel: """Return the model without locking it.""" - return self.locker.model() + return self.locker.model class ModelLoaderBase(ABC): @@ -89,9 +61,9 @@ class ModelLoaderBase(ABC): @abstractmethod def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel: """ - Return a model given its key. + Return a model given its confguration. - Given a model key identified in the model configuration backend, + Given a model identified in the model configuration backend, return a ModelInfo object that can be used to retrieve the model. :param model_config: Model configuration, as returned by ModelConfigRecordStore @@ -115,34 +87,32 @@ class AnyModelLoader: # this tracks the loader subclasses _registry: Dict[str, Type[ModelLoaderBase]] = {} - @inject def __init__( self, - store: ModelRecordServiceBase, app_config: InvokeAIAppConfig, logger: Logger, ram_cache: ModelCacheBase, convert_cache: ModelConvertCacheBase, ): - """Store the provided ModelRecordServiceBase and empty the registry.""" - self._store = store + """Initialize AnyModelLoader with its dependencies.""" self._app_config = app_config self._logger = logger self._ram_cache = ram_cache self._convert_cache = convert_cache - def get_model(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel: - """ - Return a model given its key. + @property + def ram_cache(self) -> ModelCacheBase: + """Return the RAM cache associated used by the loaders.""" + return self._ram_cache - Given a model key identified in the model configuration backend, - return a ModelInfo object that can be used to retrieve the model. + def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType]=None) -> LoadedModel: + """ + Return a model given its configuration. :param key: model key, as known to the config backend :param submodel_type: an ModelType enum indicating the portion of the model to retrieve (e.g. 
ModelType.Vae) """ - model_config = self._store.get_model(key) implementation = self.__class__.get_implementation( base=model_config.base, type=model_config.type, format=model_config.format ) @@ -165,7 +135,7 @@ class AnyModelLoader: implementation = cls._registry.get(key1) or cls._registry.get(key2) if not implementation: raise NotImplementedError( - "No subclass of LoadedModel is registered for base={base}, type={type}, format={format}" + f"No subclass of LoadedModel is registered for base={base}, type={type}, format={format}" ) return implementation @@ -176,18 +146,10 @@ class AnyModelLoader: """Define a decorator which registers the subclass of loader.""" def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]: - print("Registering class", subclass.__name__) + print("DEBUG: Registering class", subclass.__name__) key = cls._to_registry_key(base, type, format) cls._registry[key] = subclass return subclass return decorator - -# in _init__.py will call something like -# def configure_loader_dependencies(binder): -# binder.bind(ModelRecordServiceBase, ApiDependencies.invoker.services.model_records, scope=singleton) -# binder.bind(InvokeAIAppConfig, ApiDependencies.invoker.services.configuration, scope=singleton) -# etc -# injector = Injector(configure_loader_dependencies) -# loader = injector.get(ModelFactory) diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py index eb2d432aaa..0b028235fd 100644 --- a/invokeai/backend/model_manager/load/load_default.py +++ b/invokeai/backend/model_manager/load/load_default.py @@ -8,15 +8,14 @@ from typing import Any, Dict, Optional, Tuple from diffusers import ModelMixin from diffusers.configuration_utils import ConfigMixin -from injector import inject from invokeai.app.services.config import InvokeAIAppConfig from invokeai.backend.model_manager import AnyModelConfig, InvalidModelConfigException, ModelRepoVariant, SubModelType -from invokeai.backend.model_manager.convert_cache import ModelConvertCacheBase +from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase from invokeai.backend.model_manager.load.load_base import AnyModel, LoadedModel, ModelLoaderBase from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init -from invokeai.backend.model_manager.ram_cache import ModelCacheBase, ModelLockerBase +from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase from invokeai.backend.util.devices import choose_torch_device, torch_dtype @@ -35,7 +34,6 @@ class ConfigLoader(ConfigMixin): class ModelLoader(ModelLoaderBase): """Default implementation of ModelLoaderBase.""" - @inject # can inject instances of each of the classes in the call signature def __init__( self, app_config: InvokeAIAppConfig, @@ -87,18 +85,15 @@ class ModelLoader(ModelLoaderBase): def _convert_if_needed( self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None ) -> Path: - if not self._needs_conversion(config): - return model_path + cache_path: Path = self._convert_cache.cache_path(config.key) + + if not self._needs_conversion(config, model_path, cache_path): + return cache_path if cache_path.exists() else model_path self._convert_cache.make_room(self._size or self.get_size_fs(config, model_path, submodel_type)) - cache_path: Path = self._convert_cache.cache_path(config.key) - if 
cache_path.exists(): - return cache_path + return self._convert_model(config, model_path, cache_path) - self._convert_model(model_path, cache_path) - return cache_path - - def _needs_conversion(self, config: AnyModelConfig) -> bool: + def _needs_conversion(self, config: AnyModelConfig, model_path: Path, cache_path: Path) -> bool: return False def _load_if_needed( @@ -133,7 +128,7 @@ class ModelLoader(ModelLoaderBase): variant=config.repo_variant if hasattr(config, "repo_variant") else None, ) - def _convert_model(self, model_path: Path, cache_path: Path) -> None: + def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path: raise NotImplementedError def _load_model( diff --git a/invokeai/backend/model_manager/load/model_cache/__init__.py b/invokeai/backend/model_manager/load/model_cache/__init__.py new file mode 100644 index 0000000000..776b9d8936 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_cache/__init__.py @@ -0,0 +1,5 @@ +"""Init file for RamCache.""" + +from .model_cache_base import ModelCacheBase +from .model_cache_default import ModelCache +_all__ = ['ModelCacheBase', 'ModelCache'] diff --git a/invokeai/backend/model_manager/load/ram_cache/ram_cache_base.py b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py similarity index 77% rename from invokeai/backend/model_manager/load/ram_cache/ram_cache_base.py rename to invokeai/backend/model_manager/load/model_cache/model_cache_base.py index cd80d1e78b..50b69d961c 100644 --- a/invokeai/backend/model_manager/load/ram_cache/ram_cache_base.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py @@ -10,34 +10,41 @@ model will be cleared and (re)loaded from disk when next needed. from abc import ABC, abstractmethod from dataclasses import dataclass, field from logging import Logger -from typing import Dict, Optional +from typing import Dict, Optional, TypeVar, Generic import torch -from invokeai.backend.model_manager import SubModelType -from invokeai.backend.model_manager.load.load_base import AnyModel, ModelLockerBase +from invokeai.backend.model_manager import AnyModel, SubModelType +class ModelLockerBase(ABC): + """Base class for the model locker used by the loader.""" + + @abstractmethod + def lock(self) -> AnyModel: + """Lock the contained model and move it into VRAM.""" + pass + + @abstractmethod + def unlock(self) -> None: + """Unlock the contained model, and remove it from VRAM.""" + pass + + @property + @abstractmethod + def model(self) -> AnyModel: + """Return the model.""" + pass + +T = TypeVar("T") @dataclass -class CacheStats(object): - """Data object to record statistics on cache hits/misses.""" - - hits: int = 0 # cache hits - misses: int = 0 # cache misses - high_watermark: int = 0 # amount of cache used - in_cache: int = 0 # number of models in cache - cleared: int = 0 # number of models cleared to make space - cache_size: int = 0 # total size of cache - loaded_model_sizes: Dict[str, int] = field(default_factory=dict) - - -@dataclass -class CacheRecord: +class CacheRecord(Generic[T]): """Elements of the cache.""" key: str - model: AnyModel + model: T size: int + loaded: bool = False _locks: int = 0 def lock(self) -> None: @@ -55,7 +62,7 @@ class CacheRecord: return self._locks > 0 -class ModelCacheBase(ABC): +class ModelCacheBase(ABC, Generic[T]): """Virtual base class for RAM model cache.""" @property @@ -76,8 +83,14 @@ class ModelCacheBase(ABC): """Return true if the cache is configured to lazily offload models in VRAM.""" pass + 
@property @abstractmethod - def offload_unlocked_models(self) -> None: + def max_cache_size(self) -> float: + """Return true if the cache is configured to lazily offload models in VRAM.""" + pass + + @abstractmethod + def offload_unlocked_models(self, size_required: int) -> None: """Offload from VRAM any models not actively in use.""" pass @@ -101,7 +114,7 @@ class ModelCacheBase(ABC): def put( self, key: str, - model: AnyModel, + model: T, submodel_type: Optional[SubModelType] = None, ) -> None: """Store model under key and optional submodel_type.""" @@ -134,11 +147,6 @@ class ModelCacheBase(ABC): """Get the total size of the models currently cached.""" pass - @abstractmethod - def get_stats(self) -> CacheStats: - """Return cache hit/miss/size statistics.""" - pass - @abstractmethod def print_cuda_stats(self) -> None: """Log debugging information on CUDA usage.""" diff --git a/invokeai/backend/model_manager/load/ram_cache/ram_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py similarity index 63% rename from invokeai/backend/model_manager/load/ram_cache/ram_cache_default.py rename to invokeai/backend/model_manager/load/model_cache/model_cache_default.py index bd43e978c8..961f68a4be 100644 --- a/invokeai/backend/model_manager/load/ram_cache/ram_cache_default.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py @@ -18,6 +18,7 @@ context. Use like this: """ +import gc import math import time from contextlib import suppress @@ -26,14 +27,14 @@ from typing import Any, Dict, List, Optional import torch -from invokeai.app.services.model_records import UnknownModelException from invokeai.backend.model_manager import SubModelType -from invokeai.backend.model_manager.load.load_base import AnyModel, ModelLockerBase +from invokeai.backend.model_manager.load.load_base import AnyModel from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data -from invokeai.backend.model_manager.load.ram_cache.ram_cache_base import CacheRecord, CacheStats, ModelCacheBase from invokeai.backend.util.devices import choose_torch_device from invokeai.backend.util.logging import InvokeAILogger +from .model_cache_base import CacheRecord, ModelCacheBase +from .model_locker import ModelLockerBase, ModelLocker if choose_torch_device() == torch.device("mps"): from torch import mps @@ -52,7 +53,7 @@ GIG = 1073741824 MB = 2**20 -class ModelCache(ModelCacheBase): +class ModelCache(ModelCacheBase[AnyModel]): """Implementation of ModelCacheBase.""" def __init__( @@ -92,62 +93,9 @@ class ModelCache(ModelCacheBase): self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__) self._log_memory_usage = log_memory_usage - # used for stats collection - self.stats = None - - self._cached_models: Dict[str, CacheRecord] = {} + self._cached_models: Dict[str, CacheRecord[AnyModel]] = {} self._cache_stack: List[str] = [] - class ModelLocker(ModelLockerBase): - """Internal class that mediates movement in and out of GPU.""" - - def __init__(self, cache: ModelCacheBase, cache_entry: CacheRecord): - """ - Initialize the model locker. 
- - :param cache: The ModelCache object - :param cache_entry: The entry in the model cache - """ - self._cache = cache - self._cache_entry = cache_entry - - @property - def model(self) -> AnyModel: - """Return the model without moving it around.""" - return self._cache_entry.model - - def lock(self) -> Any: - """Move the model into the execution device (GPU) and lock it.""" - if not hasattr(self.model, "to"): - return self.model - - # NOTE that the model has to have the to() method in order for this code to move it into GPU! - self._cache_entry.lock() - - try: - if self._cache.lazy_offloading: - self._cache.offload_unlocked_models() - - self._cache.move_model_to_device(self._cache_entry, self._cache.execution_device) - - self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}") - self._cache.print_cuda_stats() - - except Exception: - self._cache_entry.unlock() - raise - return self.model - - def unlock(self) -> None: - """Call upon exit from context.""" - if not hasattr(self.model, "to"): - return - - self._cache_entry.unlock() - if not self._cache.lazy_offloading: - self._cache.offload_unlocked_models() - self._cache.print_cuda_stats() - @property def logger(self) -> Logger: """Return the logger used by the cache.""" @@ -168,6 +116,11 @@ class ModelCache(ModelCacheBase): """Return the exection device (e.g. "cuda" for VRAM).""" return self._execution_device + @property + def max_cache_size(self) -> float: + """Return the cap on cache size.""" + return self._max_cache_size + def cache_size(self) -> int: """Get the total size of the models currently cached.""" total = 0 @@ -207,18 +160,18 @@ class ModelCache(ModelCacheBase): """ Retrieve model using key and optional submodel_type. - This may return an UnknownModelException if the model is not in the cache. + This may return an IndexError if the model is not in the cache. 
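For orientation, a hedged sketch of how callers are expected to use the locker returned by get(); the helper below is illustrative and the try/finally shape is an assumption, not code from this diff:

from typing import Any, Callable

def run_with_model(cache: "ModelCacheBase", key: str, fn: Callable[[Any], Any]) -> Any:
    """Pin a cached model on the execution device for the duration of fn()."""
    locker = cache.get(key)   # raises IndexError if the key was never put() into the cache
    model = locker.lock()     # moves the model into VRAM and marks the cache entry locked
    try:
        return fn(model)
    finally:
        locker.unlock()       # unlocked entries become candidates for offload_unlocked_models()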
""" key = self._make_cache_key(key, submodel_type) if key not in self._cached_models: - raise UnknownModelException + raise IndexError(f"The model with key {key} is not in the cache.") # this moves the entry to the top (right end) of the stack with suppress(Exception): self._cache_stack.remove(key) self._cache_stack.append(key) cache_entry = self._cached_models[key] - return self.ModelLocker( + return ModelLocker( cache=self, cache_entry=cache_entry, ) @@ -234,19 +187,19 @@ class ModelCache(ModelCacheBase): else: return model_key - def offload_unlocked_models(self) -> None: + def offload_unlocked_models(self, size_required: int) -> None: """Move any unused models from VRAM.""" reserved = self._max_vram_cache_size * GIG - vram_in_use = torch.cuda.memory_allocated() - self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB") + vram_in_use = torch.cuda.memory_allocated() + size_required + self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB") for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size): if vram_in_use <= reserved: break if not cache_entry.locked: self.move_model_to_device(cache_entry, self.storage_device) - - vram_in_use = torch.cuda.memory_allocated() - self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB") + cache_entry.loaded = False + vram_in_use = torch.cuda.memory_allocated() + size_required + self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM now available for models; max allowed={(reserved/GIG):.2f}GB") torch.cuda.empty_cache() if choose_torch_device() == torch.device("mps"): @@ -305,28 +258,111 @@ class ModelCache(ModelCacheBase): def print_cuda_stats(self) -> None: """Log CUDA diagnostics.""" vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG) - ram = "%4.2fG" % self.cache_size() + ram = "%4.2fG" % (self.cache_size() / GIG) - cached_models = 0 - loaded_models = 0 - locked_models = 0 + in_ram_models = 0 + in_vram_models = 0 + locked_in_vram_models = 0 for cache_record in self._cached_models.values(): - cached_models += 1 assert hasattr(cache_record.model, "device") - if cache_record.model.device is self.storage_device: - loaded_models += 1 + if cache_record.model.device == self.storage_device: + in_ram_models += 1 + else: + in_vram_models += 1 if cache_record.locked: - locked_models += 1 + locked_in_vram_models += 1 self.logger.debug( - f"Current VRAM/RAM usage: {vram}/{ram}; cached_models/loaded_models/locked_models/ =" - f" {cached_models}/{loaded_models}/{locked_models}" + f"Current VRAM/RAM usage: {vram}/{ram}; models_in_ram/models_in_vram(locked) =" + f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})" ) - def get_stats(self) -> CacheStats: - """Return cache hit/miss/size statistics.""" - raise NotImplementedError - - def make_room(self, size: int) -> None: + def make_room(self, model_size: int) -> None: """Make enough room in the cache to accommodate a new model of indicated size.""" - raise NotImplementedError + # calculate how much memory this model will require + # multiplier = 2 if self.precision==torch.float32 else 1 + bytes_needed = model_size + maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes + current_size = self.cache_size() + + if current_size + bytes_needed > maximum_size: + self.logger.debug( + f"Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional" + f" {(bytes_needed/GIG):.2f} GB" + ) + + 
self.logger.debug(f"Before unloading: cached_models={len(self._cached_models)}") + + pos = 0 + models_cleared = 0 + while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack): + model_key = self._cache_stack[pos] + cache_entry = self._cached_models[model_key] + + refs = sys.getrefcount(cache_entry.model) + + # HACK: This is a workaround for a memory-management issue that we haven't tracked down yet. We are directly + # going against the advice in the Python docs by using `gc.get_referrers(...)` in this way: + # https://docs.python.org/3/library/gc.html#gc.get_referrers + + # manualy clear local variable references of just finished function calls + # for some reason python don't want to collect it even by gc.collect() immidiately + if refs > 2: + while True: + cleared = False + for referrer in gc.get_referrers(cache_entry.model): + if type(referrer).__name__ == "frame": + # RuntimeError: cannot clear an executing frame + with suppress(RuntimeError): + referrer.clear() + cleared = True + # break + + # repeat if referrers changes(due to frame clear), else exit loop + if cleared: + gc.collect() + else: + break + + device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None + self.logger.debug( + f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}," + f" refs: {refs}" + ) + + # Expected refs: + # 1 from cache_entry + # 1 from getrefcount function + # 1 from onnx runtime object + if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2): + self.logger.debug( + f"Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)" + ) + current_size -= cache_entry.size + models_cleared += 1 + del self._cache_stack[pos] + del self._cached_models[model_key] + del cache_entry + + else: + pos += 1 + + if models_cleared > 0: + # There would likely be some 'garbage' to be collected regardless of whether a model was cleared or not, but + # there is a significant time cost to calling `gc.collect()`, so we want to use it sparingly. (The time cost + # is high even if no garbage gets collected.) + # + # Calling gc.collect(...) when a model is cleared seems like a good middle-ground: + # - If models had to be cleared, it's a signal that we are close to our memory limit. + # - If models were cleared, there's a good chance that there's a significant amount of garbage to be + # collected. + # + # Keep in mind that gc is only responsible for handling reference cycles. Most objects should be cleaned up + # immediately when their reference count hits 0. + gc.collect() + + torch.cuda.empty_cache() + if choose_torch_device() == torch.device("mps"): + mps.empty_cache() + + self.logger.debug(f"After unloading: cached_models={len(self._cached_models)}") diff --git a/invokeai/backend/model_manager/load/model_cache/model_locker.py b/invokeai/backend/model_manager/load/model_cache/model_locker.py new file mode 100644 index 0000000000..506d012949 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_cache/model_locker.py @@ -0,0 +1,59 @@ +""" +Base class and implementation of a class that moves models in and out of VRAM. 
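Restating the eviction test used in make_room() above as a standalone predicate (the function name is illustrative): an entry may only be dropped when it is unlocked and its reference count shows nothing outside the cache is still holding the model, allowing for the cache entry itself, the getrefcount() argument, and, for ONNX models, the runtime's wrapper object:

import sys

def can_evict(model: object, locked: bool, is_onnx: bool) -> bool:
    """Mirror of the refcount check in make_room(): True if the entry may be unloaded."""
    refs = sys.getrefcount(model)   # one ref from the cache entry plus one from this call
    return (not locked) and refs <= (3 if is_onnx else 2)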
+""" + +from abc import ABC, abstractmethod +from invokeai.backend.model_manager import AnyModel +from .model_cache_base import ModelLockerBase, ModelCacheBase, CacheRecord + +class ModelLocker(ModelLockerBase): + """Internal class that mediates movement in and out of GPU.""" + + def __init__(self, cache: ModelCacheBase[AnyModel], cache_entry: CacheRecord[AnyModel]): + """ + Initialize the model locker. + + :param cache: The ModelCache object + :param cache_entry: The entry in the model cache + """ + self._cache = cache + self._cache_entry = cache_entry + + @property + def model(self) -> AnyModel: + """Return the model without moving it around.""" + return self._cache_entry.model + + def lock(self) -> AnyModel: + """Move the model into the execution device (GPU) and lock it.""" + if not hasattr(self.model, "to"): + return self.model + + # NOTE that the model has to have the to() method in order for this code to move it into GPU! + self._cache_entry.lock() + + try: + if self._cache.lazy_offloading: + self._cache.offload_unlocked_models(self._cache_entry.size) + + self._cache.move_model_to_device(self._cache_entry, self._cache.execution_device) + self._cache_entry.loaded = True + + self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}") + self._cache.print_cuda_stats() + + except Exception: + self._cache_entry.unlock() + raise + return self.model + + def unlock(self) -> None: + """Call upon exit from context.""" + if not hasattr(self.model, "to"): + return + + self._cache_entry.unlock() + if not self._cache.lazy_offloading: + self._cache.offload_unlocked_models(self._cache_entry.size) + self._cache.print_cuda_stats() + diff --git a/invokeai/backend/model_manager/load/model_loaders/__init__.py b/invokeai/backend/model_manager/load/model_loaders/__init__.py new file mode 100644 index 0000000000..962cba5481 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_loaders/__init__.py @@ -0,0 +1,3 @@ +""" +Init file for model_loaders. +""" diff --git a/invokeai/backend/model_manager/load/model_loaders/vae.py b/invokeai/backend/model_manager/load/model_loaders/vae.py new file mode 100644 index 0000000000..6f21c3d090 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_loaders/vae.py @@ -0,0 +1,83 @@ +# Copyright (c) 2024, Lincoln D. 
Stein and the InvokeAI Development Team +"""Class for VAE model loading in InvokeAI.""" + +from pathlib import Path +from typing import Optional + +import torch +import safetensors +from omegaconf import OmegaConf, DictConfig +from invokeai.backend.util.devices import torch_dtype +from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelFormat, ModelRepoVariant, ModelType, SubModelType +from invokeai.backend.model_manager.load.load_base import AnyModelLoader +from invokeai.backend.model_manager.load.load_default import ModelLoader +from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers + +@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers) +@AnyModelLoader.register(base=BaseModelType.StableDiffusion1, type=ModelType.Vae, format=ModelFormat.Checkpoint) +@AnyModelLoader.register(base=BaseModelType.StableDiffusion2, type=ModelType.Vae, format=ModelFormat.Checkpoint) +class VaeDiffusersModel(ModelLoader): + """Class to load VAE models.""" + + def _load_model( + self, + model_path: Path, + model_variant: Optional[ModelRepoVariant] = None, + submodel_type: Optional[SubModelType] = None, + ) -> AnyModel: + if submodel_type is not None: + raise Exception("There are no submodels in VAEs") + vae_class = self._get_hf_load_class(model_path) + variant = model_variant.value if model_variant else None + result: AnyModel = vae_class.from_pretrained( + model_path, torch_dtype=self._torch_dtype, variant=variant + ) # type: ignore + return result + + def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool: + print(f'DEBUG: last_modified={config.last_modified}') + print(f'DEBUG: cache_path={(dest_path / "config.json").stat().st_mtime}') + print(f'DEBUG: model_path={model_path.stat().st_mtime}') + if config.format != ModelFormat.Checkpoint: + return False + elif dest_path.exists() \ + and (dest_path / "config.json").stat().st_mtime >= config.last_modified \ + and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime: + return False + else: + return True + + def _convert_model(self, + config: AnyModelConfig, + weights_path: Path, + output_path: Path + ) -> Path: + if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}: + raise Exception(f"Vae conversion not supported for model type: {config.base}") + else: + config_file = 'v1-inference.yaml' if config.base == BaseModelType.StableDiffusion1 else "v2-inference-v.yaml" + + if weights_path.suffix == ".safetensors": + checkpoint = safetensors.torch.load_file(weights_path, device="cpu") + else: + checkpoint = torch.load(weights_path, map_location="cpu") + + dtype = torch_dtype() + + # sometimes weights are hidden under "state_dict", and sometimes not + if "state_dict" in checkpoint: + checkpoint = checkpoint["state_dict"] + + ckpt_config = OmegaConf.load(self._app_config.legacy_conf_path / config_file) + assert isinstance(ckpt_config, DictConfig) + + print(f'DEBUG: CONVERTIGN') + vae_model = convert_ldm_vae_to_diffusers( + checkpoint=checkpoint, + vae_config=ckpt_config, + image_size=512, + ) + vae_model.to(dtype) # set precision appropriately + vae_model.save_pretrained(output_path, safe_serialization=True, torch_dtype=dtype) + return output_path + diff --git a/invokeai/backend/model_manager/load/model_util.py b/invokeai/backend/model_manager/load/model_util.py index 18407cbca2..7c27e66472 100644 --- a/invokeai/backend/model_manager/load/model_util.py +++ 
@@ -48,6 +48,9 @@ def _calc_onnx_model_by_data(model: IAIOnnxRuntimeModel) -> int:

def calc_model_size_by_fs(model_path: Path, subfolder: Optional[str] = None, variant: Optional[str] = None) -> int:
    """Estimate the size of a model on disk in bytes."""
+    if model_path.is_file():
+        return model_path.stat().st_size
+
    if subfolder is not None:
        model_path = model_path / subfolder
diff --git a/invokeai/backend/model_manager/load/ram_cache/__init__.py b/invokeai/backend/model_manager/load/ram_cache/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/invokeai/backend/model_manager/load/vae.py b/invokeai/backend/model_manager/load/vae.py
deleted file mode 100644
index a6cbe241e1..0000000000
--- a/invokeai/backend/model_manager/load/vae.py
+++ /dev/null
@@ -1,31 +0,0 @@
-# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
-"""Class for VAE model loading in InvokeAI."""
-
-from pathlib import Path
-from typing import Dict, Optional
-
-import torch
-
-from invokeai.backend.model_manager import BaseModelType, ModelFormat, ModelRepoVariant, ModelType, SubModelType
-from invokeai.backend.model_manager.load.load_base import AnyModelLoader
-from invokeai.backend.model_manager.load.load_default import ModelLoader
-
-
-@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers)
-class VaeDiffusersModel(ModelLoader):
-    """Class to load VAE models."""
-
-    def _load_model(
-        self,
-        model_path: Path,
-        model_variant: Optional[ModelRepoVariant] = None,
-        submodel_type: Optional[SubModelType] = None,
-    ) -> Dict[str, torch.Tensor]:
-        if submodel_type is not None:
-            raise Exception("There are no submodels in VAEs")
-        vae_class = self._get_hf_load_class(model_path)
-        variant = model_variant.value if model_variant else ""
-        result: Dict[str, torch.Tensor] = vae_class.from_pretrained(
-            model_path, torch_dtype=self._torch_dtype, variant=variant
-        )  # type: ignore
-        return result
diff --git a/invokeai/backend/util/__init__.py b/invokeai/backend/util/__init__.py
index 87ae1480f5..0164dffe30 100644
--- a/invokeai/backend/util/__init__.py
+++ b/invokeai/backend/util/__init__.py
@@ -12,6 +12,14 @@ from .devices import (  # noqa: F401
    torch_dtype,
)
from .logging import InvokeAILogger
-from .util import Chdir, ask_user, download_with_resume, instantiate_from_config, url_attachment_name  # noqa: F401
+from .util import (  # TO DO: Clean this up; remove the unused symbols
+    GIG,
+    Chdir,
+    ask_user,  # noqa
+    directory_size,
+    download_with_resume,
+    instantiate_from_config,  # noqa
+    url_attachment_name,  # noqa
+)

-__all__ = ["Chdir", "InvokeAILogger", "choose_precision", "choose_torch_device"]
+__all__ = ["GIG", "directory_size", "Chdir", "download_with_resume", "InvokeAILogger", "choose_precision", "choose_torch_device"]
diff --git a/invokeai/backend/util/devices.py b/invokeai/backend/util/devices.py
index d6d3ad727f..ad3f4e139a 100644
--- a/invokeai/backend/util/devices.py
+++ b/invokeai/backend/util/devices.py
@@ -1,7 +1,7 @@
from __future__ import annotations

from contextlib import nullcontext
-from typing import Union
+from typing import Union, Optional

import torch
from torch import autocast
@@ -43,7 +43,8 @@ def choose_precision(device: torch.device) -> str:
    return "float32"


-def torch_dtype(device: torch.device) -> torch.dtype:
+def torch_dtype(device: Optional[torch.device] = None) -> torch.dtype:
+    device = device or choose_torch_device()
    precision = choose_precision(device)
    if precision == "float16":
        return torch.float16
diff --git a/invokeai/backend/util/util.py b/invokeai/backend/util/util.py
index 13751e2770..6589aa7278 100644
--- a/invokeai/backend/util/util.py
+++ b/invokeai/backend/util/util.py
@@ -24,6 +24,20 @@ import invokeai.backend.util.logging as logger

from .devices import torch_dtype

+# actual size of a gig
+GIG = 1073741824
+
+def directory_size(directory: Path) -> int:
+    """
+    Return the aggregate size of all files in a directory (bytes).
+    """
+    sum = 0
+    for root, dirs, files in os.walk(directory):
+        for f in files:
+            sum += Path(root, f).stat().st_size
+        for d in dirs:
+            sum += Path(root, d).stat().st_size
+    return sum

def log_txt_as_img(wh, xc, size=10):  # wh a tuple of (width, height)