Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
Load single-file checkpoints directly without conversion (#6510)
* use model_class.load_singlefile() instead of converting; works, but performance is poor
* adjust the convert api - not right just yet
* working, needs sql migrator update
* rename migration_11 before conflict merge with main
* Update invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py
  Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>
* Update invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py
  Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>
* implement lightweight version-by-version config migration
* simplified config schema migration code
* associate sdxl config with sdxl VAEs
* remove use of original_config_file in load_single_file()
---------
Co-authored-by: Lincoln Stein <lstein@gmail.com>
Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>
This commit is contained in:
parent: aba16085a5
commit: 3e0fb45dd7
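The core idea of the change, as a minimal sketch (the checkpoint path and dtype below are placeholders, not part of this commit): diffusers can load a single-file checkpoint directly, which is what the loaders in this diff now do instead of writing a converted diffusers copy to disk first.

import torch
from diffusers import StableDiffusionPipeline

# Hypothetical local checkpoint; no on-disk conversion is produced.
pipe = StableDiffusionPipeline.from_single_file(
    "models/v1-5-pruned-emaonly.safetensors",
    torch_dtype=torch.float16,
)
unet = pipe.unet  # submodels are plain attributes of the loaded pipeline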
@@ -3,9 +3,9 @@

import io
import pathlib
import shutil
import traceback
from copy import deepcopy
from tempfile import TemporaryDirectory
from typing import Any, Dict, List, Optional, Type

from fastapi import Body, Path, Query, Response, UploadFile
@@ -19,7 +19,6 @@ from typing_extensions import Annotated
from invokeai.app.services.model_images.model_images_common import ModelImageFileNotFoundException
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
from invokeai.app.services.model_records import (
DuplicateModelException,
InvalidModelException,
ModelRecordChanges,
UnknownModelException,
@@ -30,7 +29,6 @@ from invokeai.backend.model_manager.config import (
MainCheckpointConfig,
ModelFormat,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
from invokeai.backend.model_manager.metadata.metadata_base import ModelMetadataWithFiles, UnknownMetadataException
@@ -174,18 +172,6 @@ async def get_model_record(
raise HTTPException(status_code=404, detail=str(e))


# @model_manager_router.get("/summary", operation_id="list_model_summary")
# async def list_model_summary(
# page: int = Query(default=0, description="The page to get"),
# per_page: int = Query(default=10, description="The number of models per page"),
# order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
# ) -> PaginatedResults[ModelSummary]:
# """Gets a page of model summary data."""
# record_store = ApiDependencies.invoker.services.model_manager.store
# results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by)
# return results


class FoundModel(BaseModel):
path: str = Field(description="Path to the model")
is_installed: bool = Field(description="Whether or not the model is already installed")
@@ -746,39 +732,36 @@ async def convert_model(
logger.error(f"The model with key {key} is not a main checkpoint model.")
raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.")

# loading the model will convert it into a cached diffusers file
try:
cc_size = loader.convert_cache.max_size
if cc_size == 0: # temporary set the convert cache to a positive number so that cached model is written
loader._convert_cache.max_size = 1.0
loader.load_model(model_config, submodel_type=SubModelType.Scheduler)
finally:
loader._convert_cache.max_size = cc_size
with TemporaryDirectory(dir=ApiDependencies.invoker.services.configuration.models_path) as tmpdir:
convert_path = pathlib.Path(tmpdir) / pathlib.Path(model_config.path).stem
converted_model = loader.load_model(model_config)
# write the converted file to the convert path
raw_model = converted_model.model
assert hasattr(raw_model, "save_pretrained")
raw_model.save_pretrained(convert_path)
assert convert_path.exists()

# Get the path of the converted model from the loader
cache_path = loader.convert_cache.cache_path(key)
assert cache_path.exists()
# temporarily rename the original safetensors file so that there is no naming conflict
original_name = model_config.name
model_config.name = f"{original_name}.DELETE"
changes = ModelRecordChanges(name=model_config.name)
store.update_model(key, changes=changes)

# temporarily rename the original safetensors file so that there is no naming conflict
original_name = model_config.name
model_config.name = f"{original_name}.DELETE"
changes = ModelRecordChanges(name=model_config.name)
store.update_model(key, changes=changes)

# install the diffusers
try:
new_key = installer.install_path(
cache_path,
config={
"name": original_name,
"description": model_config.description,
"hash": model_config.hash,
"source": model_config.source,
},
)
except DuplicateModelException as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
# install the diffusers
try:
new_key = installer.install_path(
convert_path,
config={
"name": original_name,
"description": model_config.description,
"hash": model_config.hash,
"source": model_config.source,
},
)
except Exception as e:
logger.error(str(e))
store.update_model(key, changes=ModelRecordChanges(name=original_name))
raise HTTPException(status_code=409, detail=str(e))

# Update the model image if the model had one
try:
@@ -791,8 +774,8 @@ async def convert_model(
# delete the original safetensors file
installer.delete(key)

# delete the cached version
shutil.rmtree(cache_path)
# delete the temporary directory
# shutil.rmtree(cache_path)

# return the config record for the new diffusers directory
new_config = store.get_model(new_key)
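A condensed sketch of the convert flow introduced above, with names taken from the hunk (the surrounding service objects are assumed to exist; error handling and the temporary record renaming are omitted):

import pathlib
from tempfile import TemporaryDirectory

def convert_to_diffusers(loader, installer, model_config, models_path) -> str:
    """Load a checkpoint model in memory, write it out as a diffusers folder, install it."""
    with TemporaryDirectory(dir=models_path) as tmpdir:
        convert_path = pathlib.Path(tmpdir) / pathlib.Path(model_config.path).stem
        converted = loader.load_model(model_config)     # full pipeline, no submodel_type
        converted.model.save_pretrained(convert_path)   # serialize to diffusers layout
        # install_path() copies the directory into the managed models folder and returns its key
        return installer.install_path(convert_path, config={"name": model_config.name})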
@@ -3,6 +3,7 @@

from __future__ import annotations

import copy
import locale
import os
import re
@@ -25,14 +26,13 @@ DB_FILE = Path("invokeai.db")
LEGACY_INIT_FILE = Path("invokeai.init")
DEFAULT_RAM_CACHE = 10.0
DEFAULT_VRAM_CACHE = 0.25
DEFAULT_CONVERT_CACHE = 20.0
DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"]
PRECISION = Literal["auto", "float16", "bfloat16", "float32"]
ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"]
ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"]
LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"]
CONFIG_SCHEMA_VERSION = "4.0.1"
CONFIG_SCHEMA_VERSION = "4.0.2"


def get_default_ram_cache_size() -> float:
@@ -85,7 +85,7 @@ class InvokeAIAppConfig(BaseSettings):
log_tokenization: Enable logging of parsed prompt tokens.
patchmatch: Enable patchmatch inpaint code.
models_dir: Path to the models directory.
convert_cache_dir: Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location.
convert_cache_dir: Path to the converted models cache directory (DEPRECATED, but do not delete because it is needed for migration from previous versions).
download_cache_dir: Path to the directory that contains dynamically downloaded models.
legacy_conf_dir: Path to directory of legacy checkpoint config files.
db_dir: Path to InvokeAI databases directory.
@@ -102,7 +102,6 @@ class InvokeAIAppConfig(BaseSettings):
profiles_dir: Path to profiles output directory.
ram: Maximum memory amount used by memory model cache for rapid switching (GB).
vram: Amount of VRAM reserved for model storage (GB).
convert_cache: Maximum size of on-disk converted models cache (GB).
lazy_offload: Keep models in VRAM until their space is needed.
log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
@@ -148,7 +147,7 @@ class InvokeAIAppConfig(BaseSettings):

# PATHS
models_dir: Path = Field(default=Path("models"), description="Path to the models directory.")
convert_cache_dir: Path = Field(default=Path("models/.convert_cache"), description="Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location.")
convert_cache_dir: Path = Field(default=Path("models/.convert_cache"), description="Path to the converted models cache directory (DEPRECATED, but do not delete because it is needed for migration from previous versions).")
download_cache_dir: Path = Field(default=Path("models/.download_cache"), description="Path to the directory that contains dynamically downloaded models.")
legacy_conf_dir: Path = Field(default=Path("configs"), description="Path to directory of legacy checkpoint config files.")
db_dir: Path = Field(default=Path("databases"), description="Path to InvokeAI databases directory.")
@@ -170,9 +169,8 @@ class InvokeAIAppConfig(BaseSettings):
profiles_dir: Path = Field(default=Path("profiles"), description="Path to profiles output directory.")

# CACHE
ram: float = Field(default_factory=get_default_ram_cache_size, gt=0, description="Maximum memory amount used by memory model cache for rapid switching (GB).")
vram: float = Field(default=DEFAULT_VRAM_CACHE, ge=0, description="Amount of VRAM reserved for model storage (GB).")
convert_cache: float = Field(default=DEFAULT_CONVERT_CACHE, ge=0, description="Maximum size of on-disk converted models cache (GB).")
ram: float = Field(default_factory=get_default_ram_cache_size, gt=0, description="Maximum memory amount used by memory model cache for rapid switching (GB).")
vram: float = Field(default=DEFAULT_VRAM_CACHE, ge=0, description="Amount of VRAM reserved for model storage (GB).")
lazy_offload: bool = Field(default=True, description="Keep models in VRAM until their space is needed.")
log_memory_usage: bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.")

@@ -357,14 +355,14 @@ class DefaultInvokeAIAppConfig(InvokeAIAppConfig):
return (init_settings,)


def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
"""Migrate a v3 config dictionary to a current config object.
def migrate_v3_config_dict(config_dict: dict[str, Any]) -> dict[str, Any]:
"""Migrate a v3 config dictionary to a v4.0.0.

Args:
config_dict: A dictionary of settings from a v3 config file.

Returns:
An instance of `InvokeAIAppConfig` with the migrated settings.
An `InvokeAIAppConfig` config dict.

"""
parsed_config_dict: dict[str, Any] = {}
@@ -398,32 +396,41 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
elif k in InvokeAIAppConfig.model_fields:
# skip unknown fields
parsed_config_dict[k] = v
# When migrating the config file, we should not include currently-set environment variables.
config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict)

return config
parsed_config_dict["schema_version"] = "4.0.0"
return parsed_config_dict


def migrate_v4_0_0_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
"""Migrate v4.0.0 config dictionary to a current config object.
def migrate_v4_0_0_to_4_0_1_config_dict(config_dict: dict[str, Any]) -> dict[str, Any]:
"""Migrate v4.0.0 config dictionary to a v4.0.1 config dictionary

Args:
config_dict: A dictionary of settings from a v4.0.0 config file.

Returns:
An instance of `InvokeAIAppConfig` with the migrated settings.
A config dict with the settings migrated to v4.0.1.
"""
parsed_config_dict: dict[str, Any] = {}
for k, v in config_dict.items():
# autocast was removed from precision in v4.0.1
if k == "precision" and v == "autocast":
parsed_config_dict["precision"] = "auto"
else:
parsed_config_dict[k] = v
if k == "schema_version":
parsed_config_dict[k] = CONFIG_SCHEMA_VERSION
config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict)
return config
parsed_config_dict: dict[str, Any] = copy.deepcopy(config_dict)
# precision "autocast" was replaced by "auto" in v4.0.1
if parsed_config_dict.get("precision") == "autocast":
parsed_config_dict["precision"] = "auto"
parsed_config_dict["schema_version"] = "4.0.1"
return parsed_config_dict


def migrate_v4_0_1_to_4_0_2_config_dict(config_dict: dict[str, Any]) -> dict[str, Any]:
"""Migrate v4.0.1 config dictionary to a v4.0.2 config dictionary.

Args:
config_dict: A dictionary of settings from a v4.0.1 config file.

Returns:
An config dict with the settings migrated to v4.0.2.
"""
parsed_config_dict: dict[str, Any] = copy.deepcopy(config_dict)
# convert_cache was removed in 4.0.2
parsed_config_dict.pop("convert_cache", None)
parsed_config_dict["schema_version"] = "4.0.2"
return parsed_config_dict


def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
@@ -437,27 +444,31 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
"""
assert config_path.suffix == ".yaml"
with open(config_path, "rt", encoding=locale.getpreferredencoding()) as file:
loaded_config_dict = yaml.safe_load(file)
loaded_config_dict: dict[str, Any] = yaml.safe_load(file)

assert isinstance(loaded_config_dict, dict)

migrated = False
if "InvokeAI" in loaded_config_dict:
# This is a v3 config file, attempt to migrate it
migrated = True
loaded_config_dict = migrate_v3_config_dict(loaded_config_dict)  # pyright: ignore [reportUnknownArgumentType]
if loaded_config_dict["schema_version"] == "4.0.0":
migrated = True
loaded_config_dict = migrate_v4_0_0_to_4_0_1_config_dict(loaded_config_dict)
if loaded_config_dict["schema_version"] == "4.0.1":
migrated = True
loaded_config_dict = migrate_v4_0_1_to_4_0_2_config_dict(loaded_config_dict)

if migrated:
shutil.copy(config_path, config_path.with_suffix(".yaml.bak"))
try:
# loaded_config_dict could be the wrong shape, but we will catch all exceptions below
migrated_config = migrate_v3_config_dict(loaded_config_dict)  # pyright: ignore [reportUnknownArgumentType]
# load and write without environment variables
migrated_config = DefaultInvokeAIAppConfig.model_validate(loaded_config_dict)
migrated_config.write_file(config_path)
except Exception as e:
shutil.copy(config_path.with_suffix(".yaml.bak"), config_path)
raise RuntimeError(f"Failed to load and migrate v3 config file {config_path}: {e}") from e
migrated_config.write_file(config_path)
return migrated_config

if loaded_config_dict["schema_version"] == "4.0.0":
loaded_config_dict = migrate_v4_0_0_config_dict(loaded_config_dict)
loaded_config_dict.write_file(config_path)

# Attempt to load as a v4 config file
try:
# Meta is not included in the model fields, so we need to validate it separately
config = InvokeAIAppConfig.model_validate(loaded_config_dict)
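The migration chain above operates on plain dicts, one schema version at a time. A standalone sketch of the same pattern (function names here are illustrative, not the ones in the codebase):

from typing import Any, Callable

MigrationFn = Callable[[dict[str, Any]], dict[str, Any]]

def v4_0_0_to_v4_0_1(cfg: dict[str, Any]) -> dict[str, Any]:
    out = dict(cfg)
    if out.get("precision") == "autocast":  # "autocast" was replaced by "auto"
        out["precision"] = "auto"
    out["schema_version"] = "4.0.1"
    return out

def v4_0_1_to_v4_0_2(cfg: dict[str, Any]) -> dict[str, Any]:
    out = dict(cfg)
    out.pop("convert_cache", None)          # setting removed in 4.0.2
    out["schema_version"] = "4.0.2"
    return out

MIGRATIONS: dict[str, MigrationFn] = {"4.0.0": v4_0_0_to_v4_0_1, "4.0.1": v4_0_1_to_v4_0_2}

def migrate(cfg: dict[str, Any]) -> dict[str, Any]:
    # Apply steps until the dict reaches a version with no registered migration.
    while cfg.get("schema_version") in MIGRATIONS:
        cfg = MIGRATIONS[cfg["schema_version"]](cfg)
    return cfg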
@@ -7,7 +7,6 @@ from typing import Callable, Optional

from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
from invokeai.backend.model_manager.load import LoadedModel, LoadedModelWithoutConfig
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase


@@ -28,11 +27,6 @@ class ModelLoadServiceBase(ABC):
def ram_cache(self) -> ModelCacheBase[AnyModel]:
"""Return the RAM cache used by this loader."""

@property
@abstractmethod
def convert_cache(self) -> ModelConvertCacheBase:
"""Return the checkpoint convert cache used by this loader."""

@abstractmethod
def load_model_from_path(
self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
@@ -17,7 +17,6 @@ from invokeai.backend.model_manager.load import (
ModelLoaderRegistry,
ModelLoaderRegistryBase,
)
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
from invokeai.backend.util.devices import TorchDevice
@@ -33,7 +32,6 @@ class ModelLoadService(ModelLoadServiceBase):
self,
app_config: InvokeAIAppConfig,
ram_cache: ModelCacheBase[AnyModel],
convert_cache: ModelConvertCacheBase,
registry: Optional[Type[ModelLoaderRegistryBase]] = ModelLoaderRegistry,
):
"""Initialize the model load service."""
@@ -42,7 +40,6 @@ class ModelLoadService(ModelLoadServiceBase):
self._logger = logger
self._app_config = app_config
self._ram_cache = ram_cache
self._convert_cache = convert_cache
self._registry = registry

def start(self, invoker: Invoker) -> None:
@@ -53,11 +50,6 @@ class ModelLoadService(ModelLoadServiceBase):
"""Return the RAM cache used by this loader."""
return self._ram_cache

@property
def convert_cache(self) -> ModelConvertCacheBase:
"""Return the checkpoint convert cache used by this loader."""
return self._convert_cache

def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""
Given a model's configuration, load it and return the LoadedModel object.
@@ -76,7 +68,6 @@ class ModelLoadService(ModelLoadServiceBase):
app_config=self._app_config,
logger=self._logger,
ram_cache=self._ram_cache,
convert_cache=self._convert_cache,
).load_model(model_config, submodel_type)

if hasattr(self, "_invoker"):
@@ -7,7 +7,7 @@ import torch
from typing_extensions import Self

from invokeai.app.services.invoker import Invoker
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
from invokeai.backend.model_manager.load import ModelCache, ModelLoaderRegistry
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger

@@ -86,11 +86,9 @@ class ModelManagerService(ModelManagerServiceBase):
logger=logger,
execution_device=execution_device or TorchDevice.choose_torch_device(),
)
convert_cache = ModelConvertCache(cache_path=app_config.convert_cache_path, max_size=app_config.convert_cache)
loader = ModelLoadService(
app_config=app_config,
ram_cache=ram_cache,
convert_cache=convert_cache,
registry=ModelLoaderRegistry,
)
installer = ModelInstallService(
@@ -14,6 +14,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_8 import
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_9 import build_migration_9
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_10 import build_migration_10
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_11 import build_migration_11
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_12 import build_migration_12
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator


@@ -45,6 +46,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator.register_migration(build_migration_9())
migrator.register_migration(build_migration_10())
migrator.register_migration(build_migration_11(app_config=config, logger=logger))
migrator.register_migration(build_migration_12(app_config=config))
migrator.run_migrations()

return db
@@ -0,0 +1,35 @@
import shutil
import sqlite3

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration


class Migration12Callback:
def __init__(self, app_config: InvokeAIAppConfig) -> None:
self._app_config = app_config

def __call__(self, cursor: sqlite3.Cursor) -> None:
self._remove_model_convert_cache_dir()

def _remove_model_convert_cache_dir(self) -> None:
"""
Removes unused model convert cache directory
"""
convert_cache = self._app_config.convert_cache_path
shutil.rmtree(convert_cache, ignore_errors=True)


def build_migration_12(app_config: InvokeAIAppConfig) -> Migration:
"""
Build the migration from database version 11 to 12.

This migration removes the now-unused model convert cache directory.
"""
migration_12 = Migration(
from_version=11,
to_version=12,
callback=Migration12Callback(app_config),
)

return migration_12
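Because Migration12Callback ignores the database and only deletes the convert cache directory, it can be exercised directly; an illustrative sketch (app_config is assumed to be a configured InvokeAIAppConfig):

import sqlite3

callback = Migration12Callback(app_config)   # app_config: InvokeAIAppConfig (assumed)
with sqlite3.connect(":memory:") as conn:
    callback(conn.cursor())                  # removes app_config.convert_cache_path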
@@ -24,6 +24,7 @@ import time
from enum import Enum
from typing import Literal, Optional, Type, TypeAlias, Union

import diffusers
import torch
from diffusers.models.modeling_utils import ModelMixin
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
@@ -37,7 +38,7 @@ from ..raw_model import RawModel

# ModelMixin is the base class for all diffusers and transformers models
# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
AnyModel = Union[ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor]]
AnyModel = Union[ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor], diffusers.DiffusionPipeline]


class InvalidModelConfigException(Exception):
@@ -1,83 +0,0 @@
# Adapted for use in InvokeAI by Lincoln Stein, July 2023
#
"""Conversion script for the Stable Diffusion checkpoints."""

from pathlib import Path
from typing import Optional

import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
convert_ldm_vae_checkpoint,
create_vae_diffusers_config,
download_controlnet_from_original_ckpt,
download_from_original_stable_diffusion_ckpt,
)
from omegaconf import DictConfig

from . import AnyModel


def convert_ldm_vae_to_diffusers(
checkpoint: torch.Tensor | dict[str, torch.Tensor],
vae_config: DictConfig,
image_size: int,
dump_path: Optional[Path] = None,
precision: torch.dtype = torch.float16,
) -> AutoencoderKL:
"""Convert a checkpoint-style VAE into a Diffusers VAE"""
vae_config = create_vae_diffusers_config(vae_config, image_size=image_size)
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)

vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_checkpoint)
vae.to(precision)

if dump_path:
vae.save_pretrained(dump_path, safe_serialization=True)

return vae


def convert_ckpt_to_diffusers(
checkpoint_path: str | Path,
dump_path: Optional[str | Path] = None,
precision: torch.dtype = torch.float16,
use_safetensors: bool = True,
**kwargs,
) -> AnyModel:
"""
Takes all the arguments of download_from_original_stable_diffusion_ckpt(),
and in addition a path-like object indicating the location of the desired diffusers
model to be written.
"""
pipe = download_from_original_stable_diffusion_ckpt(Path(checkpoint_path).as_posix(), **kwargs)
pipe = pipe.to(precision)

# TO DO: save correct repo variant
if dump_path:
pipe.save_pretrained(
dump_path,
safe_serialization=use_safetensors,
)
return pipe


def convert_controlnet_to_diffusers(
checkpoint_path: Path,
dump_path: Optional[Path] = None,
precision: torch.dtype = torch.float16,
**kwargs,
) -> AnyModel:
"""
Takes all the arguments of download_controlnet_from_original_ckpt(),
and in addition a path-like object indicating the location of the desired diffusers
model to be written.
"""
pipe = download_controlnet_from_original_ckpt(checkpoint_path.as_posix(), **kwargs)
pipe = pipe.to(precision)

# TO DO: save correct repo variant
if dump_path:
pipe.save_pretrained(dump_path, safe_serialization=True)
return pipe
@@ -6,7 +6,6 @@ Init file for the model loader.
from importlib import import_module
from pathlib import Path

from .convert_cache.convert_cache_default import ModelConvertCache
from .load_base import LoadedModel, LoadedModelWithoutConfig, ModelLoaderBase
from .load_default import ModelLoader
from .model_cache.model_cache_default import ModelCache
@@ -21,7 +20,6 @@ __all__ = [
"LoadedModel",
"LoadedModelWithoutConfig",
"ModelCache",
"ModelConvertCache",
"ModelLoaderBase",
"ModelLoader",
"ModelLoaderRegistryBase",
@@ -1,4 +0,0 @@
from .convert_cache_base import ModelConvertCacheBase
from .convert_cache_default import ModelConvertCache

__all__ = ["ModelConvertCacheBase", "ModelConvertCache"]
@@ -1,28 +0,0 @@
"""
Disk-based converted model cache.
"""

from abc import ABC, abstractmethod
from pathlib import Path


class ModelConvertCacheBase(ABC):
@property
@abstractmethod
def max_size(self) -> float:
"""Return the maximum size of this cache directory."""
pass

@abstractmethod
def make_room(self, size: float) -> None:
"""
Make sufficient room in the cache directory for a model of max_size.

:param size: Size required (GB)
"""
pass

@abstractmethod
def cache_path(self, key: str) -> Path:
"""Return the path for a model with the indicated key."""
pass
@@ -1,83 +0,0 @@
"""
Placeholder for convert cache implementation.
"""

import shutil
from pathlib import Path

from invokeai.backend.util import GIG, directory_size
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.backend.util.util import safe_filename

from .convert_cache_base import ModelConvertCacheBase


class ModelConvertCache(ModelConvertCacheBase):
def __init__(self, cache_path: Path, max_size: float = 10.0):
"""Initialize the convert cache with the base directory and a limit on its maximum size (in GBs)."""
if not cache_path.exists():
cache_path.mkdir(parents=True)
self._cache_path = cache_path
self._max_size = max_size

# adjust cache size at startup in case it has been changed
if self._cache_path.exists():
self.make_room(0.0)

@property
def max_size(self) -> float:
"""Return the maximum size of this cache directory (GB)."""
return self._max_size

@max_size.setter
def max_size(self, value: float) -> None:
"""Set the maximum size of this cache directory (GB)."""
self._max_size = value

def cache_path(self, key: str) -> Path:
"""Return the path for a model with the indicated key."""
key = safe_filename(self._cache_path, key)
return self._cache_path / key

def make_room(self, size: float) -> None:
"""
Make sufficient room in the cache directory for a model of max_size.

:param size: Size required (GB)
"""
size_needed = directory_size(self._cache_path) + size
max_size = int(self.max_size) * GIG
logger = InvokeAILogger.get_logger()

if size_needed <= max_size:
return

logger.debug(
f"Convert cache has gotten too large {(size_needed / GIG):4.2f} > {(max_size / GIG):4.2f}G.. Trimming."
)

# For this to work, we make the assumption that the directory contains
# a 'model_index.json', 'unet/config.json' file, or a 'config.json' file at top level.
# This should be true for any diffusers model.
def by_atime(path: Path) -> float:
for config in ["model_index.json", "unet/config.json", "config.json"]:
sentinel = path / config
if sentinel.exists():
return sentinel.stat().st_atime

# no sentinel file found! - pick the most recent file in the directory
try:
atimes = sorted([x.stat().st_atime for x in path.iterdir() if x.is_file()], reverse=True)
return atimes[0]
except IndexError:
return 0.0

# sort by last access time - least accessed files will be at the end
lru_models = sorted(self._cache_path.iterdir(), key=by_atime, reverse=True)
logger.debug(f"cached models in descending atime order: {lru_models}")
while size_needed > max_size and len(lru_models) > 0:
next_victim = lru_models.pop()
victim_size = directory_size(next_victim)
logger.debug(f"Removing cached converted model {next_victim} to free {victim_size / GIG} GB")
shutil.rmtree(next_victim)
size_needed -= victim_size
@@ -18,7 +18,6 @@ from invokeai.backend.model_manager.config import (
AnyModelConfig,
SubModelType,
)
from invokeai.backend.model_manager.load.convert_cache.convert_cache_base import ModelConvertCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase


@@ -112,7 +111,6 @@ class ModelLoaderBase(ABC):
app_config: InvokeAIAppConfig,
logger: Logger,
ram_cache: ModelCacheBase[AnyModel],
convert_cache: ModelConvertCacheBase,
):
"""Initialize the loader."""
pass
@@ -138,12 +136,6 @@ class ModelLoaderBase(ABC):
"""Return size in bytes of the model, calculated before loading."""
pass

@property
@abstractmethod
def convert_cache(self) -> ModelConvertCacheBase:
"""Return the convert cache associated with this loader."""
pass

@property
@abstractmethod
def ram_cache(self) -> ModelCacheBase[AnyModel]:
@@ -12,8 +12,7 @@ from invokeai.backend.model_manager import (
InvalidModelConfigException,
SubModelType,
)
from invokeai.backend.model_manager.config import DiffusersConfigBase, ModelType
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.config import DiffusersConfigBase
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
@@ -30,13 +29,11 @@ class ModelLoader(ModelLoaderBase):
app_config: InvokeAIAppConfig,
logger: Logger,
ram_cache: ModelCacheBase[AnyModel],
convert_cache: ModelConvertCacheBase,
):
"""Initialize the loader."""
self._app_config = app_config
self._logger = logger
self._ram_cache = ram_cache
self._convert_cache = convert_cache
self._torch_dtype = TorchDevice.choose_torch_dtype()

def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
@@ -50,23 +47,15 @@ class ModelLoader(ModelLoaderBase):
:param submodel_type: an ModelType enum indicating the portion of
the model to retrieve (e.g. ModelType.Vae)
"""
if model_config.type is ModelType.Main and not submodel_type:
raise InvalidModelConfigException("submodel_type is required when loading a main model")

model_path = self._get_model_path(model_config)

if not model_path.exists():
raise InvalidModelConfigException(f"Files for model '{model_config.name}' not found at {model_path}")

with skip_torch_weight_init():
locker = self._convert_and_load(model_config, model_path, submodel_type)
locker = self._load_and_cache(model_config, submodel_type)
return LoadedModel(config=model_config, _locker=locker)

@property
def convert_cache(self) -> ModelConvertCacheBase:
"""Return the convert cache associated with this loader."""
return self._convert_cache

@property
def ram_cache(self) -> ModelCacheBase[AnyModel]:
"""Return the ram cache associated with this loader."""
@@ -76,20 +65,14 @@ class ModelLoader(ModelLoaderBase):
model_base = self._app_config.models_path
return (model_base / config.path).resolve()

def _convert_and_load(
self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
) -> ModelLockerBase:
def _load_and_cache(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> ModelLockerBase:
try:
return self._ram_cache.get(config.key, submodel_type)
except IndexError:
pass

cache_path: Path = self._convert_cache.cache_path(str(model_path))
if self._needs_conversion(config, model_path, cache_path):
loaded_model = self._do_convert(config, model_path, cache_path, submodel_type)
else:
config.path = str(cache_path) if cache_path.exists() else str(self._get_model_path(config))
loaded_model = self._load_model(config, submodel_type)
config.path = str(self._get_model_path(config))
loaded_model = self._load_model(config, submodel_type)

self._ram_cache.put(
config.key,
@@ -113,28 +96,6 @@ class ModelLoader(ModelLoaderBase):
variant=config.repo_variant if isinstance(config, DiffusersConfigBase) else None,
)

def _do_convert(
self, config: AnyModelConfig, model_path: Path, cache_path: Path, submodel_type: Optional[SubModelType] = None
) -> AnyModel:
self.convert_cache.make_room(calc_model_size_by_fs(model_path))
pipeline = self._convert_model(config, model_path, cache_path if self.convert_cache.max_size > 0 else None)
if submodel_type:
# Proactively load the various submodels into the RAM cache so that we don't have to re-convert
# the entire pipeline every time a new submodel is needed.
for subtype in SubModelType:
if subtype == submodel_type:
continue
if submodel := getattr(pipeline, subtype.value, None):
self._ram_cache.put(config.key, submodel_type=subtype, model=submodel)
return getattr(pipeline, submodel_type.value) if submodel_type else pipeline

def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
return False

# This needs to be implemented in subclasses that handle checkpoints
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
raise NotImplementedError

# This needs to be implemented in the subclass
def _load_model(
self,
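The replacement _load_and_cache() above boils down to a RAM-cache-first lookup. A compressed sketch of that flow (cache and loader objects are assumed; locking details omitted):

def load_cached_or_fresh(ram_cache, config, submodel_type, load_model):
    """Return the cached (sub)model if present, otherwise load it from config.path and cache it."""
    try:
        return ram_cache.get(config.key, submodel_type)        # cache hit
    except IndexError:
        pass                                                   # cache miss
    model = load_model(config, submodel_type)                  # load directly, no conversion step
    ram_cache.put(config.key, submodel_type=submodel_type, model=model)
    return ram_cache.get(config.key, submodel_type)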
@@ -1,9 +1,10 @@
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
"""Class for ControlNet model loading in InvokeAI."""

from pathlib import Path
from typing import Optional

from diffusers import ControlNetModel

from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
@@ -11,8 +12,7 @@ from invokeai.backend.model_manager import (
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.config import CheckpointConfigBase
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_controlnet_to_diffusers
from invokeai.backend.model_manager.config import ControlNetCheckpointConfig, SubModelType

from .. import ModelLoaderRegistry
from .generic_diffusers import GenericDiffusersLoader
@@ -23,36 +23,15 @@ from .generic_diffusers import GenericDiffusersLoader
class ControlNetLoader(GenericDiffusersLoader):
"""Class to load ControlNet models."""

def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
if not isinstance(config, CheckpointConfigBase):
return False
elif (
dest_path.exists()
and (dest_path / "config.json").stat().st_mtime >= (config.converted_at or 0.0)
and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime
):
return False
else:
return True

def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
assert isinstance(config, CheckpointConfigBase)
image_size = (
512
if config.base == BaseModelType.StableDiffusion1
else 768
if config.base == BaseModelType.StableDiffusion2
else 1024
)

self._logger.info(f"Converting {model_path} to diffusers format")
with open(self._app_config.legacy_conf_path / config.config_path, "r") as config_stream:
result = convert_controlnet_to_diffusers(
model_path,
output_path,
original_config_file=config_stream,
image_size=image_size,
precision=self._torch_dtype,
from_safetensors=model_path.suffix == ".safetensors",
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if isinstance(config, ControlNetCheckpointConfig):
return ControlNetModel.from_single_file(
config.path,
torch_dtype=self._torch_dtype,
)
return result
else:
return super()._load_model(config, submodel_type)
@@ -15,7 +15,6 @@ from invokeai.backend.model_manager import (
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase

from .. import ModelLoader, ModelLoaderRegistry
@@ -32,10 +31,9 @@ class LoRALoader(ModelLoader):
app_config: InvokeAIAppConfig,
logger: Logger,
ram_cache: ModelCacheBase[AnyModel],
convert_cache: ModelConvertCacheBase,
):
"""Initialize the loader."""
super().__init__(app_config, logger, ram_cache, convert_cache)
super().__init__(app_config, logger, ram_cache)
self._model_base: Optional[BaseModelType] = None

def _load_model(
@@ -4,22 +4,28 @@
from pathlib import Path
from typing import Optional

from diffusers import (
StableDiffusionInpaintPipeline,
StableDiffusionPipeline,
StableDiffusionXLInpaintPipeline,
StableDiffusionXLPipeline,
)

from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelType,
SchedulerPredictionType,
ModelVariantType,
SubModelType,
)
from invokeai.backend.model_manager.config import (
CheckpointConfigBase,
DiffusersConfigBase,
MainCheckpointConfig,
ModelVariantType,
)
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
from invokeai.backend.util.silence_warnings import SilenceWarnings

from .. import ModelLoaderRegistry
from .generic_diffusers import GenericDiffusersLoader
@@ -48,8 +54,12 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not submodel_type is not None:
if isinstance(config, CheckpointConfigBase):
return self._load_from_singlefile(config, submodel_type)

if submodel_type is None:
raise Exception("A submodel type must be provided when loading main pipelines.")

model_path = Path(config.path)
load_class = self.get_hf_load_class(model_path, submodel_type)
repo_variant = config.repo_variant if isinstance(config, DiffusersConfigBase) else None
@@ -71,46 +81,58 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):

return result

def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
if not isinstance(config, CheckpointConfigBase):
return False
elif (
dest_path.exists()
and (dest_path / "model_index.json").stat().st_mtime >= (config.converted_at or 0.0)
and (dest_path / "model_index.json").stat().st_mtime >= model_path.stat().st_mtime
):
return False
else:
return True

def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
def _load_from_singlefile(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
load_classes = {
BaseModelType.StableDiffusion1: {
ModelVariantType.Normal: StableDiffusionPipeline,
ModelVariantType.Inpaint: StableDiffusionInpaintPipeline,
},
BaseModelType.StableDiffusion2: {
ModelVariantType.Normal: StableDiffusionPipeline,
ModelVariantType.Inpaint: StableDiffusionInpaintPipeline,
},
BaseModelType.StableDiffusionXL: {
ModelVariantType.Normal: StableDiffusionXLPipeline,
ModelVariantType.Inpaint: StableDiffusionXLInpaintPipeline,
},
}
assert isinstance(config, MainCheckpointConfig)
base = config.base

try:
load_class = load_classes[config.base][config.variant]
except KeyError as e:
raise Exception(f"No diffusers pipeline known for base={config.base}, variant={config.variant}") from e
prediction_type = config.prediction_type.value
upcast_attention = config.upcast_attention
image_size = (
1024
if base == BaseModelType.StableDiffusionXL
else 768
if config.prediction_type == SchedulerPredictionType.VPrediction and base == BaseModelType.StableDiffusion2
else 512
)

self._logger.info(f"Converting {model_path} to diffusers format")
# Without SilenceWarnings we get log messages like this:
# site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
# warnings.warn(
# Some weights of the model checkpoint were not used when initializing CLIPTextModel:
# ['text_model.embeddings.position_ids']
# Some weights of the model checkpoint were not used when initializing CLIPTextModelWithProjection:
# ['text_model.embeddings.position_ids']

loaded_model = convert_ckpt_to_diffusers(
model_path,
output_path,
model_type=self.model_base_to_model_type[base],
original_config_file=self._app_config.legacy_conf_path / config.config_path,
extract_ema=True,
from_safetensors=model_path.suffix == ".safetensors",
precision=self._torch_dtype,
prediction_type=prediction_type,
image_size=image_size,
upcast_attention=upcast_attention,
load_safety_checker=False,
num_in_channels=VARIANT_TO_IN_CHANNEL_MAP[config.variant],
)
return loaded_model
with SilenceWarnings():
pipeline = load_class.from_single_file(
config.path,
torch_dtype=self._torch_dtype,
prediction_type=prediction_type,
upcast_attention=upcast_attention,
load_safety_checker=False,
)

if not submodel_type:
return pipeline

# Proactively load the various submodels into the RAM cache so that we don't have to re-load
# the entire pipeline every time a new submodel is needed.
for subtype in SubModelType:
if subtype == submodel_type:
continue
if submodel := getattr(pipeline, subtype.value, None):
self._ram_cache.put(config.key, submodel_type=subtype, model=submodel)
return getattr(pipeline, submodel_type.value)
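Outside the loader class, the same single-file path for an SDXL checkpoint looks roughly like this (the file name is a placeholder; the loader above additionally caches every submodel it pulls out of the pipeline):

import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_single_file(
    "models/sd_xl_base_1.0.safetensors",   # hypothetical local checkpoint
    torch_dtype=torch.float16,
)
vae = pipe.vae  # equivalent to getattr(pipeline, submodel_type.value) in the hunk above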
@@ -1,12 +1,9 @@
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
"""Class for VAE model loading in InvokeAI."""

from pathlib import Path
from typing import Optional

import torch
from omegaconf import DictConfig, OmegaConf
from safetensors.torch import load_file as safetensors_load_file
from diffusers import AutoencoderKL

from invokeai.backend.model_manager import (
AnyModelConfig,
@@ -14,8 +11,7 @@ from invokeai.backend.model_manager import (
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.config import AnyModel, CheckpointConfigBase
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
from invokeai.backend.model_manager.config import AnyModel, SubModelType, VAECheckpointConfig

from .. import ModelLoaderRegistry
from .generic_diffusers import GenericDiffusersLoader
@@ -26,39 +22,15 @@ from .generic_diffusers import GenericDiffusersLoader
class VAELoader(GenericDiffusersLoader):
"""Class to load VAE models."""

def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
if not isinstance(config, CheckpointConfigBase):
return False
elif (
dest_path.exists()
and (dest_path / "config.json").stat().st_mtime >= (config.converted_at or 0.0)
and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime
):
return False
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if isinstance(config, VAECheckpointConfig):
return AutoencoderKL.from_single_file(
config.path,
torch_dtype=self._torch_dtype,
)
else:
return True

def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
assert isinstance(config, CheckpointConfigBase)
config_file = self._app_config.legacy_conf_path / config.config_path

if model_path.suffix == ".safetensors":
checkpoint = safetensors_load_file(model_path, device="cpu")
else:
checkpoint = torch.load(model_path, map_location="cpu")

# sometimes weights are hidden under "state_dict", and sometimes not
if "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"]

ckpt_config = OmegaConf.load(config_file)
assert isinstance(ckpt_config, DictConfig)
self._logger.info(f"Converting {model_path} to diffusers format")
vae_model = convert_ldm_vae_to_diffusers(
checkpoint=checkpoint,
vae_config=ckpt_config,
image_size=512,
precision=self._torch_dtype,
dump_path=output_path,
)
return vae_model
return super()._load_model(config, submodel_type)
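Standalone VAE checkpoints follow the same pattern; a minimal sketch with a placeholder path:

import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_single_file(
    "models/vae-ft-mse-840000-ema-pruned.safetensors",  # hypothetical local VAE checkpoint
    torch_dtype=torch.float16,
)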
@@ -312,6 +312,8 @@ class ModelProbe(object):
config_file = (
"stable-diffusion/v1-inference.yaml"
if base_type is BaseModelType.StableDiffusion1
else "stable-diffusion/sd_xl_base.yaml"
if base_type is BaseModelType.StableDiffusionXL
else "stable-diffusion/v2-inference.yaml"
)
else:
@@ -25,7 +25,7 @@ from invokeai.backend.model_manager.config import (
ModelVariantType,
VAEDiffusersConfig,
)
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache
from invokeai.backend.model_manager.load import ModelCache
from invokeai.backend.util.logging import InvokeAILogger
from tests.backend.model_manager.model_metadata.metadata_examples import (
HFTestLoraMetadata,
@@ -89,17 +89,15 @@ def mm2_download_queue(mm2_session: Session) -> DownloadQueueServiceBase:


@pytest.fixture
def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceBase) -> ModelLoadServiceBase:
def mm2_loader(mm2_app_config: InvokeAIAppConfig) -> ModelLoadServiceBase:
ram_cache = ModelCache(
logger=InvokeAILogger.get_logger(),
max_cache_size=mm2_app_config.ram,
max_vram_cache_size=mm2_app_config.vram,
)
convert_cache = ModelConvertCache(mm2_app_config.convert_cache_path)
return ModelLoadService(
app_config=mm2_app_config,
ram_cache=ram_cache,
convert_cache=convert_cache,
)