Model loading and conversion implemented for VAEs

This commit is contained in:
Lincoln Stein 2024-02-03 22:55:09 -05:00 committed by psychedelicious
parent 5c2884569e
commit 60aa3d4893
29 changed files with 2382 additions and 237 deletions

View File

@ -8,6 +8,8 @@ from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMe
from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk
from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache
from invokeai.app.services.shared.sqlite.sqlite_util import init_db from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.backend.model_manager.load import AnyModelLoader, ModelConvertCache
from invokeai.backend.model_manager.load.model_cache import ModelCache
from invokeai.backend.model_manager.metadata import ModelMetadataStore from invokeai.backend.model_manager.metadata import ModelMetadataStore
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
@ -98,15 +100,26 @@ class ApiDependencies:
) )
model_manager = ModelManagerService(config, logger) model_manager = ModelManagerService(config, logger)
model_record_service = ModelRecordServiceSQL(db=db) model_record_service = ModelRecordServiceSQL(db=db)
model_loader = AnyModelLoader(
app_config=config,
logger=logger,
ram_cache=ModelCache(
max_cache_size=config.ram_cache_size, max_vram_cache_size=config.vram_cache_size, logger=logger
),
convert_cache=ModelConvertCache(
cache_path=config.models_convert_cache_path, max_size=config.convert_cache_size
),
)
model_record_service = ModelRecordServiceSQL(db=db, loader=model_loader)
download_queue_service = DownloadQueueService(event_bus=events) download_queue_service = DownloadQueueService(event_bus=events)
metadata_store = ModelMetadataStore(db=db)
model_install_service = ModelInstallService( model_install_service = ModelInstallService(
app_config=config, app_config=config,
record_store=model_record_service, record_store=model_record_service,
download_queue=download_queue_service, download_queue=download_queue_service,
metadata_store=metadata_store, metadata_store=ModelMetadataStore(db=db),
event_bus=events, event_bus=events,
) )
model_manager = ModelManagerService(config, logger) # TO DO: legacy model manager v1. Remove
names = SimpleNameService() names = SimpleNameService()
performance_statistics = InvocationStatsService() performance_statistics = InvocationStatsService()
processor = DefaultInvocationProcessor() processor = DefaultInvocationProcessor()
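
For readers tracing the wiring above, here is a minimal sketch of how the new pieces appear to fit together, assuming the constructor signatures shown in this commit; build_record_service and its import paths are illustrative, not part of the diff:

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records import ModelRecordServiceSQL
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.backend.model_manager.load import AnyModelLoader, ModelConvertCache
from invokeai.backend.model_manager.load.model_cache import ModelCache
from invokeai.backend.util.logging import InvokeAILogger

def build_record_service(config: InvokeAIAppConfig, db: SqliteDatabase) -> ModelRecordServiceSQL:
    """Hypothetical helper mirroring the ApiDependencies wiring above."""
    logger = InvokeAILogger.get_logger()
    loader = AnyModelLoader(
        app_config=config,
        logger=logger,
        ram_cache=ModelCache(
            max_cache_size=config.ram_cache_size,       # GB of RAM for loaded models
            max_vram_cache_size=config.vram_cache_size,
            logger=logger,
        ),
        convert_cache=ModelConvertCache(
            cache_path=config.models_convert_cache_path,
            max_size=config.convert_cache_size,          # GB of on-disk converted models
        ),
    )
    # Passing the loader is what enables record_store.load_model(key) further down in this commit.
    return ModelRecordServiceSQL(db=db, loader=loader)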

View File

@ -237,6 +237,7 @@ class InvokeAIAppConfig(InvokeAISettings):
autoimport_dir : Path = Field(default=Path('autoimport'), description='Path to a directory of models files to be imported on startup.', json_schema_extra=Categories.Paths) autoimport_dir : Path = Field(default=Path('autoimport'), description='Path to a directory of models files to be imported on startup.', json_schema_extra=Categories.Paths)
conf_path : Path = Field(default=Path('configs/models.yaml'), description='Path to models definition file', json_schema_extra=Categories.Paths) conf_path : Path = Field(default=Path('configs/models.yaml'), description='Path to models definition file', json_schema_extra=Categories.Paths)
models_dir : Path = Field(default=Path('models'), description='Path to the models directory', json_schema_extra=Categories.Paths) models_dir : Path = Field(default=Path('models'), description='Path to the models directory', json_schema_extra=Categories.Paths)
convert_cache_dir : Path = Field(default=Path('models/.cache'), description='Path to the converted models cache directory', json_schema_extra=Categories.Paths)
legacy_conf_dir : Path = Field(default=Path('configs/stable-diffusion'), description='Path to directory of legacy checkpoint config files', json_schema_extra=Categories.Paths) legacy_conf_dir : Path = Field(default=Path('configs/stable-diffusion'), description='Path to directory of legacy checkpoint config files', json_schema_extra=Categories.Paths)
db_dir : Path = Field(default=Path('databases'), description='Path to InvokeAI databases directory', json_schema_extra=Categories.Paths) db_dir : Path = Field(default=Path('databases'), description='Path to InvokeAI databases directory', json_schema_extra=Categories.Paths)
outdir : Path = Field(default=Path('outputs'), description='Default folder for output images', json_schema_extra=Categories.Paths) outdir : Path = Field(default=Path('outputs'), description='Default folder for output images', json_schema_extra=Categories.Paths)
@ -262,6 +263,8 @@ class InvokeAIAppConfig(InvokeAISettings):
# CACHE # CACHE
ram : float = Field(default=7.5, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number, GB)", json_schema_extra=Categories.ModelCache, ) ram : float = Field(default=7.5, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number, GB)", json_schema_extra=Categories.ModelCache, )
vram : float = Field(default=0.25, ge=0, description="Amount of VRAM reserved for model storage (floating point number, GB)", json_schema_extra=Categories.ModelCache, ) vram : float = Field(default=0.25, ge=0, description="Amount of VRAM reserved for model storage (floating point number, GB)", json_schema_extra=Categories.ModelCache, )
convert_cache : float = Field(default=10.0, ge=0, description="Maximum size of on-disk converted models cache (GB)", json_schema_extra=Categories.ModelCache)
lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", json_schema_extra=Categories.ModelCache, ) lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", json_schema_extra=Categories.ModelCache, )
log_memory_usage : bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.", json_schema_extra=Categories.ModelCache) log_memory_usage : bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.", json_schema_extra=Categories.ModelCache)
@ -404,6 +407,11 @@ class InvokeAIAppConfig(InvokeAISettings):
"""Path to the models directory.""" """Path to the models directory."""
return self._resolve(self.models_dir) return self._resolve(self.models_dir)
@property
def models_convert_cache_path(self) -> Path:
"""Path to the converted cache models directory."""
return self._resolve(self.convert_cache_dir)
@property @property
def custom_nodes_path(self) -> Path: def custom_nodes_path(self) -> Path:
"""Path to the custom nodes directory.""" """Path to the custom nodes directory."""
@ -433,15 +441,20 @@ class InvokeAIAppConfig(InvokeAISettings):
return True return True
@property @property
def ram_cache_size(self) -> Union[Literal["auto"], float]: def ram_cache_size(self) -> float:
"""Return the ram cache size using the legacy or modern setting.""" """Return the ram cache size using the legacy or modern setting (GB)."""
return self.max_cache_size or self.ram return self.max_cache_size or self.ram
@property @property
def vram_cache_size(self) -> Union[Literal["auto"], float]: def vram_cache_size(self) -> float:
"""Return the vram cache size using the legacy or modern setting.""" """Return the vram cache size using the legacy or modern setting (GB)."""
return self.max_vram_cache_size or self.vram return self.max_vram_cache_size or self.vram
@property
def convert_cache_size(self) -> float:
"""Return the convert cache size on disk (GB)."""
return self.convert_cache
@property @property
def use_cpu(self) -> bool: def use_cpu(self) -> bool:
"""Return true if the device is set to CPU or the always_use_cpu flag is set.""" """Return true if the device is set to CPU or the always_use_cpu flag is set."""

View File

@ -145,7 +145,7 @@ class ModelInstallService(ModelInstallServiceBase):
) -> str: # noqa D102 ) -> str: # noqa D102
model_path = Path(model_path) model_path = Path(model_path)
config = config or {} config = config or {}
if config.get("source") is None: if not config.get("source"):
config["source"] = model_path.resolve().as_posix() config["source"] = model_path.resolve().as_posix()
return self._register(model_path, config) return self._register(model_path, config)
@ -156,7 +156,7 @@ class ModelInstallService(ModelInstallServiceBase):
) -> str: # noqa D102 ) -> str: # noqa D102
model_path = Path(model_path) model_path = Path(model_path)
config = config or {} config = config or {}
if config.get("source") is None: if not config.get("source"):
config["source"] = model_path.resolve().as_posix() config["source"] = model_path.resolve().as_posix()
info: AnyModelConfig = self._probe_model(Path(model_path), config) info: AnyModelConfig = self._probe_model(Path(model_path), config)
@ -300,6 +300,7 @@ class ModelInstallService(ModelInstallServiceBase):
job.total_bytes = self._stat_size(job.local_path) job.total_bytes = self._stat_size(job.local_path)
job.bytes = job.total_bytes job.bytes = job.total_bytes
self._signal_job_running(job) self._signal_job_running(job)
job.config_in["source"] = str(job.source)
if job.inplace: if job.inplace:
key = self.register_path(job.local_path, job.config_in) key = self.register_path(job.local_path, job.config_in)
else: else:
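
The practical effect of the two "not config.get('source')" changes is that an empty source string now also falls back to the resolved model path, not only a missing key. A tiny sketch with a hypothetical path:

from pathlib import Path
from typing import Optional

def default_source(model_path: Path, config: Optional[dict]) -> dict:
    # Mirrors the guard above: both None and "" fall through to the resolved path.
    config = config or {}
    if not config.get("source"):
        config["source"] = model_path.resolve().as_posix()
    return config

print(default_source(Path("models/sd-1/vae/my-vae.safetensors"), {"source": ""}))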

View File

@ -11,7 +11,7 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from invokeai.app.services.shared.pagination import PaginatedResults from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType from invokeai.backend.model_manager import LoadedModel, AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore
@ -102,6 +102,19 @@ class ModelRecordServiceBase(ABC):
""" """
pass pass
@abstractmethod
def load_model(self, key: str, submodel_type: Optional[SubModelType]) -> LoadedModel:
"""
Load the indicated model into memory and return a LoadedModel object.
:param key: Key of model config to be fetched.
:param submodel_type: For main (pipeline) models, the submodel to fetch.
Exceptions: UnknownModelException -- model with this key not known
NotImplementedException -- a model loader was not provided at initialization time
"""
pass
@property @property
@abstractmethod @abstractmethod
def metadata_store(self) -> ModelMetadataStore: def metadata_store(self) -> ModelMetadataStore:

View File

@ -42,6 +42,7 @@ Typical usage:
import json import json
import sqlite3 import sqlite3
import time
from math import ceil from math import ceil
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, Union from typing import Any, Dict, List, Optional, Set, Tuple, Union
@ -53,8 +54,10 @@ from invokeai.backend.model_manager.config import (
ModelConfigFactory, ModelConfigFactory,
ModelFormat, ModelFormat,
ModelType, ModelType,
SubModelType,
) )
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore, UnknownMetadataException from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore, UnknownMetadataException
from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel
from ..shared.sqlite.sqlite_database import SqliteDatabase from ..shared.sqlite.sqlite_database import SqliteDatabase
from .model_records_base import ( from .model_records_base import (
@ -69,16 +72,17 @@ from .model_records_base import (
class ModelRecordServiceSQL(ModelRecordServiceBase): class ModelRecordServiceSQL(ModelRecordServiceBase):
"""Implementation of the ModelConfigStore ABC using a SQL database.""" """Implementation of the ModelConfigStore ABC using a SQL database."""
def __init__(self, db: SqliteDatabase): def __init__(self, db: SqliteDatabase, loader: Optional[AnyModelLoader]=None):
""" """
Initialize a new object from preexisting sqlite3 connection and threading lock objects. Initialize a new object from preexisting sqlite3 connection and threading lock objects.
:param conn: sqlite3 connection object :param db: Sqlite connection object
:param lock: threading Lock object :param loader: Initialized model loader object (optional)
""" """
super().__init__() super().__init__()
self._db = db self._db = db
self._cursor = self._db.conn.cursor() self._cursor = db.conn.cursor()
self._loader = loader
@property @property
def db(self) -> SqliteDatabase: def db(self) -> SqliteDatabase:
@ -199,7 +203,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
with self._db.lock: with self._db.lock:
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
SELECT config FROM model_config SELECT config, strftime('%s',updated_at) FROM model_config
WHERE id=?; WHERE id=?;
""", """,
(key,), (key,),
@ -207,9 +211,24 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
rows = self._cursor.fetchone() rows = self._cursor.fetchone()
if not rows: if not rows:
raise UnknownModelException("model not found") raise UnknownModelException("model not found")
model = ModelConfigFactory.make_config(json.loads(rows[0])) model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
return model return model
def load_model(self, key: str, submodel_type: Optional[SubModelType]) -> LoadedModel:
"""
Load the indicated model into memory and return a LoadedModel object.
:param key: Key of model config to be fetched.
:param submodel_type: For main (pipeline) models, the submodel to fetch.
Exceptions: UnknownModelException -- model with this key not known
NotImplementedException -- a model loader was not provided at initialization time
"""
if not self._loader:
raise NotImplementedError(f"Class {self.__class__} was not initialized with a model loader")
model_config = self.get_model(key)
return self._loader.load_model(model_config, submodel_type)
def exists(self, key: str) -> bool: def exists(self, key: str) -> bool:
""" """
Return True if a model with the indicated key exists in the databse. Return True if a model with the indicated key exists in the databse.
@ -265,12 +284,12 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
with self._db.lock: with self._db.lock:
self._cursor.execute( self._cursor.execute(
f"""--sql f"""--sql
select config FROM model_config select config, strftime('%s',updated_at) FROM model_config
{where}; {where};
""", """,
tuple(bindings), tuple(bindings),
) )
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()] results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()]
return results return results
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]: def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
@ -279,12 +298,12 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
with self._db.lock: with self._db.lock:
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
SELECT config FROM model_config SELECT config, strftime('%s',updated_at) FROM model_config
WHERE path=?; WHERE path=?;
""", """,
(str(path),), (str(path),),
) )
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()] results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()]
return results return results
def search_by_hash(self, hash: str) -> List[AnyModelConfig]: def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
@ -293,12 +312,12 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
with self._db.lock: with self._db.lock:
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
SELECT config FROM model_config SELECT config, strftime('%s',updated_at) FROM model_config
WHERE original_hash=?; WHERE original_hash=?;
""", """,
(hash,), (hash,),
) )
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()] results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()]
return results return results
@property @property
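
The switch from SELECT config to SELECT config, strftime('%s', updated_at) threads each row's modification time (as Unix epoch seconds) into ModelConfigFactory.make_config(..., timestamp=...). A standalone sqlite3 sketch of that plumbing, using a minimal stand-in schema:

import json
import sqlite3
import time

conn = sqlite3.connect(":memory:")
conn.execute(
    """CREATE TABLE model_config (
           id TEXT PRIMARY KEY,
           config TEXT,
           updated_at DATETIME DEFAULT (STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
       )"""
)
conn.execute("INSERT INTO model_config (id, config) VALUES (?, ?)", ("key1", json.dumps({"name": "demo"})))
row = conn.execute("SELECT config, strftime('%s', updated_at) FROM model_config WHERE id=?", ("key1",)).fetchone()
config_dict, last_modified = json.loads(row[0]), float(row[1])
print(config_dict, time.ctime(last_modified))  # the epoch value is what make_config() receives as timestamp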

View File

@ -8,6 +8,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_3 import build_migration_3 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_3 import build_migration_3
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_4 import build_migration_4 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_4 import build_migration_4
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_5 import build_migration_5 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_5 import build_migration_5
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_6 import build_migration_6
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
@ -33,6 +34,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator.register_migration(build_migration_3(app_config=config, logger=logger)) migrator.register_migration(build_migration_3(app_config=config, logger=logger))
migrator.register_migration(build_migration_4()) migrator.register_migration(build_migration_4())
migrator.register_migration(build_migration_5()) migrator.register_migration(build_migration_5())
migrator.register_migration(build_migration_6())
migrator.run_migrations() migrator.run_migrations()
return db return db

View File

@ -0,0 +1,44 @@
import sqlite3
from logging import Logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration6Callback:
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._recreate_model_triggers(cursor)
def _recreate_model_triggers(self, cursor: sqlite3.Cursor) -> None:
"""
Adds the timestamp trigger to the model_config table.
This trigger was inadvertently dropped in earlier migration scripts.
"""
cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS model_config_updated_at
AFTER UPDATE
ON model_config FOR EACH ROW
BEGIN
UPDATE model_config SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE id = old.id;
END;
"""
)
def build_migration_6() -> Migration:
"""
Build the migration from database version 5 to 6.
This migration does the following:
- Adds the model_config_updated_at trigger if it does not exist
"""
migration_6 = Migration(
from_version=5,
to_version=6,
callback=Migration6Callback(),
)
return migration_6
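
To see what the recreated trigger buys, here is a self-contained check against a minimal stand-in table (the real model_config table has more columns); it confirms that any UPDATE refreshes updated_at, which is exactly the value the record store now reads back with strftime('%s', updated_at):

import sqlite3
import time

conn = sqlite3.connect(":memory:")
conn.executescript("""
    CREATE TABLE model_config (
        id TEXT PRIMARY KEY,
        config TEXT,
        updated_at DATETIME DEFAULT (STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
    );
    CREATE TRIGGER IF NOT EXISTS model_config_updated_at
    AFTER UPDATE ON model_config FOR EACH ROW
    BEGIN
        UPDATE model_config SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
        WHERE id = old.id;
    END;
""")
conn.execute("INSERT INTO model_config (id, config) VALUES ('k', '{}')")
before = conn.execute("SELECT updated_at FROM model_config WHERE id='k'").fetchone()[0]
time.sleep(0.01)  # timestamps have millisecond resolution
conn.execute("UPDATE model_config SET config=? WHERE id=?", ('{"x": 1}', "k"))
after = conn.execute("SELECT updated_at FROM model_config WHERE id='k'").fetchone()[0]
assert after >= before  # string comparison works because the timestamp format sorts lexicographically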

View File

@ -98,11 +98,13 @@ class TqdmEventService(EventServiceBase):
super().__init__() super().__init__()
self._bars: Dict[str, tqdm] = {} self._bars: Dict[str, tqdm] = {}
self._last: Dict[str, int] = {} self._last: Dict[str, int] = {}
self._logger = InvokeAILogger.get_logger(__name__)
def dispatch(self, event_name: str, payload: Any) -> None: def dispatch(self, event_name: str, payload: Any) -> None:
"""Dispatch an event by appending it to self.events.""" """Dispatch an event by appending it to self.events."""
data = payload["data"]
source = data["source"]
if payload["event"] == "model_install_downloading": if payload["event"] == "model_install_downloading":
data = payload["data"]
dest = data["local_path"] dest = data["local_path"]
total_bytes = data["total_bytes"] total_bytes = data["total_bytes"]
bytes = data["bytes"] bytes = data["bytes"]
@ -111,7 +113,12 @@ class TqdmEventService(EventServiceBase):
self._last[dest] = 0 self._last[dest] = 0
self._bars[dest].update(bytes - self._last[dest]) self._bars[dest].update(bytes - self._last[dest])
self._last[dest] = bytes self._last[dest] = bytes
elif payload["event"] == "model_install_completed":
self._logger.info(f"{source}: installed successfully.")
elif payload["event"] == "model_install_error":
self._logger.warning(f"{source}: installation failed with error {data['error']}")
elif payload["event"] == "model_install_cancelled":
self._logger.warning(f"{source}: installation cancelled")
class InstallHelper(object): class InstallHelper(object):
"""Capture information stored jointly in INITIAL_MODELS.yaml and the installed models db.""" """Capture information stored jointly in INITIAL_MODELS.yaml and the installed models db."""

View File

@ -1,6 +1,7 @@
"""Re-export frequently-used symbols from the Model Manager backend.""" """Re-export frequently-used symbols from the Model Manager backend."""
from .config import ( from .config import (
AnyModel,
AnyModelConfig, AnyModelConfig,
BaseModelType, BaseModelType,
InvalidModelConfigException, InvalidModelConfigException,
@ -14,12 +15,15 @@ from .config import (
) )
from .probe import ModelProbe from .probe import ModelProbe
from .search import ModelSearch from .search import ModelSearch
from .load import LoadedModel
__all__ = [ __all__ = [
"AnyModel",
"AnyModelConfig", "AnyModelConfig",
"BaseModelType", "BaseModelType",
"ModelRepoVariant", "ModelRepoVariant",
"InvalidModelConfigException", "InvalidModelConfigException",
"LoadedModel",
"ModelConfigFactory", "ModelConfigFactory",
"ModelFormat", "ModelFormat",
"ModelProbe", "ModelProbe",

View File

@ -19,12 +19,15 @@ Typical usage:
Validation errors will raise an InvalidModelConfigException error. Validation errors will raise an InvalidModelConfigException error.
""" """
import time
import torch
from enum import Enum from enum import Enum
from typing import Literal, Optional, Type, Union from typing import Literal, Optional, Type, Union
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
from diffusers import ModelMixin
from typing_extensions import Annotated, Any, Dict from typing_extensions import Annotated, Any, Dict
from .onnx_runtime import IAIOnnxRuntimeModel
class InvalidModelConfigException(Exception): class InvalidModelConfigException(Exception):
"""Exception for when config parser doesn't recognized this combination of model type and format.""" """Exception for when config parser doesn't recognized this combination of model type and format."""
@ -127,6 +130,7 @@ class ModelConfigBase(BaseModel):
) # if model is converted or otherwise modified, this will hold updated hash ) # if model is converted or otherwise modified, this will hold updated hash
description: Optional[str] = Field(default=None) description: Optional[str] = Field(default=None)
source: Optional[str] = Field(description="Model download source (URL or repo_id)", default=None) source: Optional[str] = Field(description="Model download source (URL or repo_id)", default=None)
last_modified: Optional[float] = Field(description="Timestamp for modification time", default_factory=time.time)
model_config = ConfigDict( model_config = ConfigDict(
use_enum_values=False, use_enum_values=False,
@ -280,6 +284,7 @@ AnyModelConfig = Union[
] ]
AnyModelConfigValidator = TypeAdapter(AnyModelConfig) AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
AnyModel = Union[ModelMixin, torch.nn.Module, IAIOnnxRuntimeModel]
# IMPLEMENTATION NOTE: # IMPLEMENTATION NOTE:
# The preferred alternative to the above is a discriminated Union as shown # The preferred alternative to the above is a discriminated Union as shown
@ -312,6 +317,7 @@ class ModelConfigFactory(object):
model_data: Union[dict, AnyModelConfig], model_data: Union[dict, AnyModelConfig],
key: Optional[str] = None, key: Optional[str] = None,
dest_class: Optional[Type] = None, dest_class: Optional[Type] = None,
timestamp: Optional[float] = None
) -> AnyModelConfig: ) -> AnyModelConfig:
""" """
Return the appropriate config object from raw dict values. Return the appropriate config object from raw dict values.
@ -330,4 +336,6 @@ class ModelConfigFactory(object):
model = AnyModelConfigValidator.validate_python(model_data) model = AnyModelConfigValidator.validate_python(model_data)
if key: if key:
model.key = key model.key = key
if timestamp:
model.last_modified = timestamp
return model return model
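
The timestamp read from the database in the previous file ends up on the new last_modified field via this factory. A short sketch, assuming the field names used by the config classes in this file (the values below are placeholders):

import time

from invokeai.backend.model_manager.config import ModelConfigFactory

raw = {
    "path": "sd-1/vae/my-vae",  # hypothetical values
    "name": "my-vae",
    "base": "sd-1",
    "type": "vae",
    "format": "diffusers",
    "original_hash": "deadbeef",
}
config = ModelConfigFactory.make_config(raw, key="abc123", timestamp=1706998509.0)
assert config.last_modified == 1706998509.0

fresh = ModelConfigFactory.make_config(dict(raw), key="def456")
assert fresh.last_modified <= time.time()  # defaults to creation time when no timestamp is given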

File diff suppressed because it is too large

View File

@ -0,0 +1,35 @@
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development Team
"""
Init file for the model loader.
"""
from importlib import import_module
from pathlib import Path
from typing import Optional
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger
from .load_base import AnyModelLoader, LoadedModel
from .model_cache.model_cache_default import ModelCache
from .convert_cache.convert_cache_default import ModelConvertCache
# This registers the subclasses that implement loaders of specific model types
loaders = [x.stem for x in Path(Path(__file__).parent,'model_loaders').glob('*.py') if x.stem != '__init__']
for module in loaders:
print(f'module={module}')
import_module(f"{__package__}.model_loaders.{module}")
__all__ = ["AnyModelLoader", "LoadedModel"]
def get_standalone_loader(app_config: Optional[InvokeAIAppConfig]) -> AnyModelLoader:
app_config = app_config or InvokeAIAppConfig.get_config()
logger = InvokeAILogger.get_logger(config=app_config)
return AnyModelLoader(app_config=app_config,
logger=logger,
ram_cache=ModelCache(logger=logger,
max_cache_size=app_config.ram_cache_size,
max_vram_cache_size=app_config.vram_cache_size
),
convert_cache=ModelConvertCache(app_config.models_convert_cache_path)
)
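
A usage sketch for the standalone helper above, outside the API process (the load_model call is commented out because it needs a real AnyModelConfig from the record store):

from invokeai.backend.model_manager.load import get_standalone_loader

loader = get_standalone_loader(None)  # falls back to InvokeAIAppConfig.get_config()
# loaded = loader.load_model(model_config)   # model_config: AnyModelConfig, optional submodel_type
# see the ModelLocker sketch later in this commit for how loaded.locker is used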

View File

@ -0,0 +1,4 @@
from .convert_cache_base import ModelConvertCacheBase
from .convert_cache_default import ModelConvertCache
__all__ = ['ModelConvertCacheBase', 'ModelConvertCache']

View File

@ -0,0 +1,28 @@
"""
Disk-based converted model cache.
"""
from abc import ABC, abstractmethod
from pathlib import Path
class ModelConvertCacheBase(ABC):
@property
@abstractmethod
def max_size(self) -> float:
"""Return the maximum size of this cache directory."""
pass
@abstractmethod
def make_room(self, size: float) -> None:
"""
Make sufficient room in the cache directory for a model of max_size.
:param size: Size required (GB)
"""
pass
@abstractmethod
def cache_path(self, key: str) -> Path:
"""Return the path for a model with the indicated key."""
pass

View File

@ -0,0 +1,64 @@
"""
Placeholder for convert cache implementation.
"""
from pathlib import Path
import shutil
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.backend.util import GIG, directory_size
from .convert_cache_base import ModelConvertCacheBase
class ModelConvertCache(ModelConvertCacheBase):
def __init__(self, cache_path: Path, max_size: float=10.0):
"""Initialize the convert cache with the base directory and a limit on its maximum size (in GBs)."""
if not cache_path.exists():
cache_path.mkdir(parents=True)
self._cache_path = cache_path
self._max_size = max_size
@property
def max_size(self) -> float:
"""Return the maximum size of this cache directory (GB)."""
return self._max_size
def cache_path(self, key: str) -> Path:
"""Return the path for a model with the indicated key."""
return self._cache_path / key
def make_room(self, size: float) -> None:
"""
Make sufficient room in the cache directory for a model of max_size.
:param size: Size required (GB)
"""
size_needed = directory_size(self._cache_path) + size
max_size = int(self.max_size) * GIG
logger = InvokeAILogger.get_logger()
if size_needed <= max_size:
return
logger.debug(
f"Convert cache has gotten too large {(size_needed / GIG):4.2f} > {(max_size / GIG):4.2f}G.. Trimming."
)
# For this to work, we make the assumption that the directory contains
# a 'model_index.json', 'unet/config.json' file, or a 'config.json' file at top level.
# This should be true for any diffusers model.
def by_atime(path: Path) -> float:
for config in ["model_index.json", "unet/config.json", "config.json"]:
sentinel = path / config
if sentinel.exists():
return sentinel.stat().st_atime
return 0.0
# sort by last access time - least accessed files will be at the end
lru_models = sorted(self._cache_path.iterdir(), key=by_atime, reverse=True)
logger.debug(f"cached models in descending atime order: {lru_models}")
while size_needed > max_size and len(lru_models) > 0:
next_victim = lru_models.pop()
victim_size = directory_size(next_victim)
logger.debug(f"Removing cached converted model {next_victim} to free {victim_size / GIG} GB")
shutil.rmtree(next_victim)
size_needed -= victim_size
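
A brief sketch of how the convert cache is driven by a loader; the key and sizes are hypothetical, and ModelLoader._convert_if_needed, later in this commit, is what passes a size in bytes obtained from get_size_fs:

import tempfile
from pathlib import Path

from invokeai.backend.model_manager.load.convert_cache import ModelConvertCache

GIG = 1073741824
cache = ModelConvertCache(cache_path=Path(tempfile.mkdtemp()) / "convert_cache", max_size=10.0)  # 10 GB cap
dest = cache.cache_path("abc123")  # directory where the converted copy of model "abc123" will live
cache.make_room(2 * GIG)           # trim least-recently-used conversions so roughly 2 GB more will fit
# ...run the conversion and save_pretrained(dest); atime on dest/config.json drives future eviction order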

View File

@ -16,39 +16,11 @@ from logging import Logger
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Dict, Optional, Type, Union from typing import Any, Callable, Dict, Optional, Type, Union
import torch
from diffusers import DiffusionPipeline
from injector import inject
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records import ModelRecordServiceBase from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
from invokeai.backend.model_manager.convert_cache import ModelConvertCacheBase from invokeai.backend.model_manager.load.model_cache.model_locker import ModelLockerBase
from invokeai.backend.model_manager.onnx_runtime import IAIOnnxRuntimeModel from invokeai.backend.model_manager.load.convert_cache.convert_cache_base import ModelConvertCacheBase
from invokeai.backend.model_manager.ram_cache import ModelCacheBase
AnyModel = Union[DiffusionPipeline, torch.nn.Module, IAIOnnxRuntimeModel]
class ModelLockerBase(ABC):
"""Base class for the model locker used by the loader."""
@abstractmethod
def lock(self) -> None:
"""Lock the contained model and move it into VRAM."""
pass
@abstractmethod
def unlock(self) -> None:
"""Unlock the contained model, and remove it from VRAM."""
pass
@property
@abstractmethod
def model(self) -> AnyModel:
"""Return the model."""
pass
@dataclass @dataclass
class LoadedModel: class LoadedModel:
@ -69,7 +41,7 @@ class LoadedModel:
@property @property
def model(self) -> AnyModel: def model(self) -> AnyModel:
"""Return the model without locking it.""" """Return the model without locking it."""
return self.locker.model() return self.locker.model
class ModelLoaderBase(ABC): class ModelLoaderBase(ABC):
@ -89,9 +61,9 @@ class ModelLoaderBase(ABC):
@abstractmethod @abstractmethod
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel: def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
""" """
Return a model given its key. Return a model given its configuration.
Given a model key identified in the model configuration backend, Given a model identified in the model configuration backend,
return a ModelInfo object that can be used to retrieve the model. return a ModelInfo object that can be used to retrieve the model.
:param model_config: Model configuration, as returned by ModelConfigRecordStore :param model_config: Model configuration, as returned by ModelConfigRecordStore
@ -115,34 +87,32 @@ class AnyModelLoader:
# this tracks the loader subclasses # this tracks the loader subclasses
_registry: Dict[str, Type[ModelLoaderBase]] = {} _registry: Dict[str, Type[ModelLoaderBase]] = {}
@inject
def __init__( def __init__(
self, self,
store: ModelRecordServiceBase,
app_config: InvokeAIAppConfig, app_config: InvokeAIAppConfig,
logger: Logger, logger: Logger,
ram_cache: ModelCacheBase, ram_cache: ModelCacheBase,
convert_cache: ModelConvertCacheBase, convert_cache: ModelConvertCacheBase,
): ):
"""Store the provided ModelRecordServiceBase and empty the registry.""" """Initialize AnyModelLoader with its dependencies."""
self._store = store
self._app_config = app_config self._app_config = app_config
self._logger = logger self._logger = logger
self._ram_cache = ram_cache self._ram_cache = ram_cache
self._convert_cache = convert_cache self._convert_cache = convert_cache
def get_model(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel: @property
""" def ram_cache(self) -> ModelCacheBase:
Return a model given its key. """Return the RAM cache associated used by the loaders."""
return self._ram_cache
Given a model key identified in the model configuration backend, def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType]=None) -> LoadedModel:
return a ModelInfo object that can be used to retrieve the model. """
Return a model given its configuration.
:param key: model key, as known to the config backend :param key: model key, as known to the config backend
:param submodel_type: an ModelType enum indicating the portion of :param submodel_type: an ModelType enum indicating the portion of
the model to retrieve (e.g. ModelType.Vae) the model to retrieve (e.g. ModelType.Vae)
""" """
model_config = self._store.get_model(key)
implementation = self.__class__.get_implementation( implementation = self.__class__.get_implementation(
base=model_config.base, type=model_config.type, format=model_config.format base=model_config.base, type=model_config.type, format=model_config.format
) )
@ -165,7 +135,7 @@ class AnyModelLoader:
implementation = cls._registry.get(key1) or cls._registry.get(key2) implementation = cls._registry.get(key1) or cls._registry.get(key2)
if not implementation: if not implementation:
raise NotImplementedError( raise NotImplementedError(
"No subclass of LoadedModel is registered for base={base}, type={type}, format={format}" f"No subclass of LoadedModel is registered for base={base}, type={type}, format={format}"
) )
return implementation return implementation
@ -176,18 +146,10 @@ class AnyModelLoader:
"""Define a decorator which registers the subclass of loader.""" """Define a decorator which registers the subclass of loader."""
def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]: def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]:
print("Registering class", subclass.__name__) print("DEBUG: Registering class", subclass.__name__)
key = cls._to_registry_key(base, type, format) key = cls._to_registry_key(base, type, format)
cls._registry[key] = subclass cls._registry[key] = subclass
return subclass return subclass
return decorator return decorator
# in _init__.py will call something like
# def configure_loader_dependencies(binder):
# binder.bind(ModelRecordServiceBase, ApiDependencies.invoker.services.model_records, scope=singleton)
# binder.bind(InvokeAIAppConfig, ApiDependencies.invoker.services.configuration, scope=singleton)
# etc
# injector = Injector(configure_loader_dependencies)
# loader = injector.get(ModelFactory)
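
The registry pattern above is what concrete loaders hook into; the VAE loader near the end of this commit uses it for real. A hypothetical registration sketch (MyTextualInversionLoader and its enum choices are illustrative, not part of the diff):

from invokeai.backend.model_manager import BaseModelType, ModelFormat, ModelType
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
from invokeai.backend.model_manager.load.load_default import ModelLoader

@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.TextualInversion, format=ModelFormat.EmbeddingFile)
class MyTextualInversionLoader(ModelLoader):
    """Would override _load_model() (and optionally _needs_conversion/_convert_model)."""

# At load time, get_implementation() looks up (base, type, format) and falls back to a
# BaseModelType.Any entry before raising NotImplementedError.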

View File

@ -8,15 +8,14 @@ from typing import Any, Dict, Optional, Tuple
from diffusers import ModelMixin from diffusers import ModelMixin
from diffusers.configuration_utils import ConfigMixin from diffusers.configuration_utils import ConfigMixin
from injector import inject
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager import AnyModelConfig, InvalidModelConfigException, ModelRepoVariant, SubModelType from invokeai.backend.model_manager import AnyModelConfig, InvalidModelConfigException, ModelRepoVariant, SubModelType
from invokeai.backend.model_manager.convert_cache import ModelConvertCacheBase from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.load_base import AnyModel, LoadedModel, ModelLoaderBase from invokeai.backend.model_manager.load.load_base import AnyModel, LoadedModel, ModelLoaderBase
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.model_manager.ram_cache import ModelCacheBase, ModelLockerBase from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
from invokeai.backend.util.devices import choose_torch_device, torch_dtype from invokeai.backend.util.devices import choose_torch_device, torch_dtype
@ -35,7 +34,6 @@ class ConfigLoader(ConfigMixin):
class ModelLoader(ModelLoaderBase): class ModelLoader(ModelLoaderBase):
"""Default implementation of ModelLoaderBase.""" """Default implementation of ModelLoaderBase."""
@inject # can inject instances of each of the classes in the call signature
def __init__( def __init__(
self, self,
app_config: InvokeAIAppConfig, app_config: InvokeAIAppConfig,
@ -87,18 +85,15 @@ class ModelLoader(ModelLoaderBase):
def _convert_if_needed( def _convert_if_needed(
self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
) -> Path: ) -> Path:
if not self._needs_conversion(config): cache_path: Path = self._convert_cache.cache_path(config.key)
return model_path
if not self._needs_conversion(config, model_path, cache_path):
return cache_path if cache_path.exists() else model_path
self._convert_cache.make_room(self._size or self.get_size_fs(config, model_path, submodel_type)) self._convert_cache.make_room(self._size or self.get_size_fs(config, model_path, submodel_type))
cache_path: Path = self._convert_cache.cache_path(config.key) return self._convert_model(config, model_path, cache_path)
if cache_path.exists():
return cache_path
self._convert_model(model_path, cache_path) def _needs_conversion(self, config: AnyModelConfig, model_path: Path, cache_path: Path) -> bool:
return cache_path
def _needs_conversion(self, config: AnyModelConfig) -> bool:
return False return False
def _load_if_needed( def _load_if_needed(
@ -133,7 +128,7 @@ class ModelLoader(ModelLoaderBase):
variant=config.repo_variant if hasattr(config, "repo_variant") else None, variant=config.repo_variant if hasattr(config, "repo_variant") else None,
) )
def _convert_model(self, model_path: Path, cache_path: Path) -> None: def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
raise NotImplementedError raise NotImplementedError
def _load_model( def _load_model(

View File

@ -0,0 +1,5 @@
"""Init file for RamCache."""
from .model_cache_base import ModelCacheBase
from .model_cache_default import ModelCache
__all__ = ['ModelCacheBase', 'ModelCache']

View File

@ -10,34 +10,41 @@ model will be cleared and (re)loaded from disk when next needed.
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, field from dataclasses import dataclass, field
from logging import Logger from logging import Logger
from typing import Dict, Optional from typing import Dict, Optional, TypeVar, Generic
import torch import torch
from invokeai.backend.model_manager import SubModelType from invokeai.backend.model_manager import AnyModel, SubModelType
from invokeai.backend.model_manager.load.load_base import AnyModel, ModelLockerBase
class ModelLockerBase(ABC):
"""Base class for the model locker used by the loader."""
@abstractmethod
def lock(self) -> AnyModel:
"""Lock the contained model and move it into VRAM."""
pass
@abstractmethod
def unlock(self) -> None:
"""Unlock the contained model, and remove it from VRAM."""
pass
@property
@abstractmethod
def model(self) -> AnyModel:
"""Return the model."""
pass
T = TypeVar("T")
@dataclass @dataclass
class CacheStats(object): class CacheRecord(Generic[T]):
"""Data object to record statistics on cache hits/misses."""
hits: int = 0 # cache hits
misses: int = 0 # cache misses
high_watermark: int = 0 # amount of cache used
in_cache: int = 0 # number of models in cache
cleared: int = 0 # number of models cleared to make space
cache_size: int = 0 # total size of cache
loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
@dataclass
class CacheRecord:
"""Elements of the cache.""" """Elements of the cache."""
key: str key: str
model: AnyModel model: T
size: int size: int
loaded: bool = False
_locks: int = 0 _locks: int = 0
def lock(self) -> None: def lock(self) -> None:
@ -55,7 +62,7 @@ class CacheRecord:
return self._locks > 0 return self._locks > 0
class ModelCacheBase(ABC): class ModelCacheBase(ABC, Generic[T]):
"""Virtual base class for RAM model cache.""" """Virtual base class for RAM model cache."""
@property @property
@ -76,8 +83,14 @@ class ModelCacheBase(ABC):
"""Return true if the cache is configured to lazily offload models in VRAM.""" """Return true if the cache is configured to lazily offload models in VRAM."""
pass pass
@property
@abstractmethod @abstractmethod
def offload_unlocked_models(self) -> None: def max_cache_size(self) -> float:
"""Return true if the cache is configured to lazily offload models in VRAM."""
pass
@abstractmethod
def offload_unlocked_models(self, size_required: int) -> None:
"""Offload from VRAM any models not actively in use.""" """Offload from VRAM any models not actively in use."""
pass pass
@ -101,7 +114,7 @@ class ModelCacheBase(ABC):
def put( def put(
self, self,
key: str, key: str,
model: AnyModel, model: T,
submodel_type: Optional[SubModelType] = None, submodel_type: Optional[SubModelType] = None,
) -> None: ) -> None:
"""Store model under key and optional submodel_type.""" """Store model under key and optional submodel_type."""
@ -134,11 +147,6 @@ class ModelCacheBase(ABC):
"""Get the total size of the models currently cached.""" """Get the total size of the models currently cached."""
pass pass
@abstractmethod
def get_stats(self) -> CacheStats:
"""Return cache hit/miss/size statistics."""
pass
@abstractmethod @abstractmethod
def print_cuda_stats(self) -> None: def print_cuda_stats(self) -> None:
"""Log debugging information on CUDA usage.""" """Log debugging information on CUDA usage."""

View File

@ -18,6 +18,7 @@ context. Use like this:
""" """
import gc
import math import math
import time import time
from contextlib import suppress from contextlib import suppress
@ -26,14 +27,14 @@ from typing import Any, Dict, List, Optional
import torch import torch
from invokeai.app.services.model_records import UnknownModelException
from invokeai.backend.model_manager import SubModelType from invokeai.backend.model_manager import SubModelType
from invokeai.backend.model_manager.load.load_base import AnyModel, ModelLockerBase from invokeai.backend.model_manager.load.load_base import AnyModel
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
from invokeai.backend.model_manager.load.ram_cache.ram_cache_base import CacheRecord, CacheStats, ModelCacheBase
from invokeai.backend.util.devices import choose_torch_device from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
from .model_cache_base import CacheRecord, ModelCacheBase
from .model_locker import ModelLockerBase, ModelLocker
if choose_torch_device() == torch.device("mps"): if choose_torch_device() == torch.device("mps"):
from torch import mps from torch import mps
@ -52,7 +53,7 @@ GIG = 1073741824
MB = 2**20 MB = 2**20
class ModelCache(ModelCacheBase): class ModelCache(ModelCacheBase[AnyModel]):
"""Implementation of ModelCacheBase.""" """Implementation of ModelCacheBase."""
def __init__( def __init__(
@ -92,62 +93,9 @@ class ModelCache(ModelCacheBase):
self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__) self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
self._log_memory_usage = log_memory_usage self._log_memory_usage = log_memory_usage
# used for stats collection self._cached_models: Dict[str, CacheRecord[AnyModel]] = {}
self.stats = None
self._cached_models: Dict[str, CacheRecord] = {}
self._cache_stack: List[str] = [] self._cache_stack: List[str] = []
class ModelLocker(ModelLockerBase):
"""Internal class that mediates movement in and out of GPU."""
def __init__(self, cache: ModelCacheBase, cache_entry: CacheRecord):
"""
Initialize the model locker.
:param cache: The ModelCache object
:param cache_entry: The entry in the model cache
"""
self._cache = cache
self._cache_entry = cache_entry
@property
def model(self) -> AnyModel:
"""Return the model without moving it around."""
return self._cache_entry.model
def lock(self) -> Any:
"""Move the model into the execution device (GPU) and lock it."""
if not hasattr(self.model, "to"):
return self.model
# NOTE that the model has to have the to() method in order for this code to move it into GPU!
self._cache_entry.lock()
try:
if self._cache.lazy_offloading:
self._cache.offload_unlocked_models()
self._cache.move_model_to_device(self._cache_entry, self._cache.execution_device)
self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}")
self._cache.print_cuda_stats()
except Exception:
self._cache_entry.unlock()
raise
return self.model
def unlock(self) -> None:
"""Call upon exit from context."""
if not hasattr(self.model, "to"):
return
self._cache_entry.unlock()
if not self._cache.lazy_offloading:
self._cache.offload_unlocked_models()
self._cache.print_cuda_stats()
@property @property
def logger(self) -> Logger: def logger(self) -> Logger:
"""Return the logger used by the cache.""" """Return the logger used by the cache."""
@ -168,6 +116,11 @@ class ModelCache(ModelCacheBase):
"""Return the exection device (e.g. "cuda" for VRAM).""" """Return the exection device (e.g. "cuda" for VRAM)."""
return self._execution_device return self._execution_device
@property
def max_cache_size(self) -> float:
"""Return the cap on cache size."""
return self._max_cache_size
def cache_size(self) -> int: def cache_size(self) -> int:
"""Get the total size of the models currently cached.""" """Get the total size of the models currently cached."""
total = 0 total = 0
@ -207,18 +160,18 @@ class ModelCache(ModelCacheBase):
""" """
Retrieve model using key and optional submodel_type. Retrieve model using key and optional submodel_type.
This may return an UnknownModelException if the model is not in the cache. This may return an IndexError if the model is not in the cache.
""" """
key = self._make_cache_key(key, submodel_type) key = self._make_cache_key(key, submodel_type)
if key not in self._cached_models: if key not in self._cached_models:
raise UnknownModelException raise IndexError(f"The model with key {key} is not in the cache.")
# this moves the entry to the top (right end) of the stack # this moves the entry to the top (right end) of the stack
with suppress(Exception): with suppress(Exception):
self._cache_stack.remove(key) self._cache_stack.remove(key)
self._cache_stack.append(key) self._cache_stack.append(key)
cache_entry = self._cached_models[key] cache_entry = self._cached_models[key]
return self.ModelLocker( return ModelLocker(
cache=self, cache=self,
cache_entry=cache_entry, cache_entry=cache_entry,
) )
@ -234,19 +187,19 @@ class ModelCache(ModelCacheBase):
else: else:
return model_key return model_key
def offload_unlocked_models(self) -> None: def offload_unlocked_models(self, size_required: int) -> None:
"""Move any unused models from VRAM.""" """Move any unused models from VRAM."""
reserved = self._max_vram_cache_size * GIG reserved = self._max_vram_cache_size * GIG
vram_in_use = torch.cuda.memory_allocated() vram_in_use = torch.cuda.memory_allocated() + size_required
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB") self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB")
for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size): for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
if vram_in_use <= reserved: if vram_in_use <= reserved:
break break
if not cache_entry.locked: if not cache_entry.locked:
self.move_model_to_device(cache_entry, self.storage_device) self.move_model_to_device(cache_entry, self.storage_device)
cache_entry.loaded = False
vram_in_use = torch.cuda.memory_allocated() vram_in_use = torch.cuda.memory_allocated() + size_required
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB") self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM now available for models; max allowed={(reserved/GIG):.2f}GB")
torch.cuda.empty_cache() torch.cuda.empty_cache()
if choose_torch_device() == torch.device("mps"): if choose_torch_device() == torch.device("mps"):
@ -305,28 +258,111 @@ class ModelCache(ModelCacheBase):
def print_cuda_stats(self) -> None: def print_cuda_stats(self) -> None:
"""Log CUDA diagnostics.""" """Log CUDA diagnostics."""
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG) vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
ram = "%4.2fG" % self.cache_size() ram = "%4.2fG" % (self.cache_size() / GIG)
cached_models = 0 in_ram_models = 0
loaded_models = 0 in_vram_models = 0
locked_models = 0 locked_in_vram_models = 0
for cache_record in self._cached_models.values(): for cache_record in self._cached_models.values():
cached_models += 1
assert hasattr(cache_record.model, "device") assert hasattr(cache_record.model, "device")
if cache_record.model.device is self.storage_device: if cache_record.model.device == self.storage_device:
loaded_models += 1 in_ram_models += 1
else:
in_vram_models += 1
if cache_record.locked: if cache_record.locked:
locked_models += 1 locked_in_vram_models += 1
self.logger.debug( self.logger.debug(
f"Current VRAM/RAM usage: {vram}/{ram}; cached_models/loaded_models/locked_models/ =" f"Current VRAM/RAM usage: {vram}/{ram}; models_in_ram/models_in_vram(locked) ="
f" {cached_models}/{loaded_models}/{locked_models}" f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})"
) )
def get_stats(self) -> CacheStats: def make_room(self, model_size: int) -> None:
"""Return cache hit/miss/size statistics."""
raise NotImplementedError
def make_room(self, size: int) -> None:
"""Make enough room in the cache to accommodate a new model of indicated size.""" """Make enough room in the cache to accommodate a new model of indicated size."""
raise NotImplementedError # calculate how much memory this model will require
# multiplier = 2 if self.precision==torch.float32 else 1
bytes_needed = model_size
maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes
current_size = self.cache_size()
if current_size + bytes_needed > maximum_size:
self.logger.debug(
f"Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional"
f" {(bytes_needed/GIG):.2f} GB"
)
self.logger.debug(f"Before unloading: cached_models={len(self._cached_models)}")
pos = 0
models_cleared = 0
while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
model_key = self._cache_stack[pos]
cache_entry = self._cached_models[model_key]
refs = sys.getrefcount(cache_entry.model)
# HACK: This is a workaround for a memory-management issue that we haven't tracked down yet. We are directly
# going against the advice in the Python docs by using `gc.get_referrers(...)` in this way:
# https://docs.python.org/3/library/gc.html#gc.get_referrers
# manually clear local variable references of just-finished function calls
# for some reason Python doesn't want to collect them even with an immediate gc.collect()
if refs > 2:
while True:
cleared = False
for referrer in gc.get_referrers(cache_entry.model):
if type(referrer).__name__ == "frame":
# RuntimeError: cannot clear an executing frame
with suppress(RuntimeError):
referrer.clear()
cleared = True
# break
# repeat if referrers changes(due to frame clear), else exit loop
if cleared:
gc.collect()
else:
break
device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
self.logger.debug(
f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded},"
f" refs: {refs}"
)
# Expected refs:
# 1 from cache_entry
# 1 from getrefcount function
# 1 from onnx runtime object
if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2):
self.logger.debug(
f"Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
)
current_size -= cache_entry.size
models_cleared += 1
del self._cache_stack[pos]
del self._cached_models[model_key]
del cache_entry
else:
pos += 1
if models_cleared > 0:
# There would likely be some 'garbage' to be collected regardless of whether a model was cleared or not, but
# there is a significant time cost to calling `gc.collect()`, so we want to use it sparingly. (The time cost
# is high even if no garbage gets collected.)
#
# Calling gc.collect(...) when a model is cleared seems like a good middle-ground:
# - If models had to be cleared, it's a signal that we are close to our memory limit.
# - If models were cleared, there's a good chance that there's a significant amount of garbage to be
# collected.
#
# Keep in mind that gc is only responsible for handling reference cycles. Most objects should be cleaned up
# immediately when their reference count hits 0.
gc.collect()
torch.cuda.empty_cache()
if choose_torch_device() == torch.device("mps"):
mps.empty_cache()
self.logger.debug(f"After unloading: cached_models={len(self._cached_models)}")

View File

@ -0,0 +1,59 @@
"""
Base class and implementation of a class that moves models in and out of VRAM.
"""
from abc import ABC, abstractmethod
from invokeai.backend.model_manager import AnyModel
from .model_cache_base import ModelLockerBase, ModelCacheBase, CacheRecord
class ModelLocker(ModelLockerBase):
"""Internal class that mediates movement in and out of GPU."""
def __init__(self, cache: ModelCacheBase[AnyModel], cache_entry: CacheRecord[AnyModel]):
"""
Initialize the model locker.
:param cache: The ModelCache object
:param cache_entry: The entry in the model cache
"""
self._cache = cache
self._cache_entry = cache_entry
@property
def model(self) -> AnyModel:
"""Return the model without moving it around."""
return self._cache_entry.model
def lock(self) -> AnyModel:
"""Move the model into the execution device (GPU) and lock it."""
if not hasattr(self.model, "to"):
return self.model
# NOTE that the model has to have the to() method in order for this code to move it into GPU!
self._cache_entry.lock()
try:
if self._cache.lazy_offloading:
self._cache.offload_unlocked_models(self._cache_entry.size)
self._cache.move_model_to_device(self._cache_entry, self._cache.execution_device)
self._cache_entry.loaded = True
self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}")
self._cache.print_cuda_stats()
except Exception:
self._cache_entry.unlock()
raise
return self.model
def unlock(self) -> None:
"""Call upon exit from context."""
if not hasattr(self.model, "to"):
return
self._cache_entry.unlock()
if not self._cache.lazy_offloading:
self._cache.offload_unlocked_models(self._cache_entry.size)
self._cache.print_cuda_stats()
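
For callers, the contract is simply lock() before use and unlock() afterwards; LoadedModel, earlier in this commit, wraps exactly this locker. A hypothetical helper showing the intended pairing:

from typing import Callable

from invokeai.backend.model_manager import AnyModel
from invokeai.backend.model_manager.load.model_cache import ModelCache

def run_locked(cache: ModelCache, key: str, fn: Callable[[AnyModel], None]) -> None:
    """Hypothetical helper: run fn() while the cached model is pinned on the execution device."""
    locker = cache.get(key)  # ModelLocker for the cache entry; raises IndexError if the key is absent
    model = locker.lock()    # moved to cache.execution_device when the model supports .to()
    try:
        fn(model)
    finally:
        locker.unlock()      # lets offload_unlocked_models() reclaim the VRAM later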

View File

@ -0,0 +1,3 @@
"""
Init file for model_loaders.
"""

View File

@ -0,0 +1,83 @@
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
"""Class for VAE model loading in InvokeAI."""
from pathlib import Path
from typing import Optional
import torch
import safetensors
from omegaconf import OmegaConf, DictConfig
from invokeai.backend.util.devices import torch_dtype
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelFormat, ModelRepoVariant, ModelType, SubModelType
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers)
@AnyModelLoader.register(base=BaseModelType.StableDiffusion1, type=ModelType.Vae, format=ModelFormat.Checkpoint)
@AnyModelLoader.register(base=BaseModelType.StableDiffusion2, type=ModelType.Vae, format=ModelFormat.Checkpoint)
class VaeDiffusersModel(ModelLoader):
"""Class to load VAE models."""
def _load_model(
self,
model_path: Path,
model_variant: Optional[ModelRepoVariant] = None,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if submodel_type is not None:
raise Exception("There are no submodels in VAEs")
vae_class = self._get_hf_load_class(model_path)
variant = model_variant.value if model_variant else None
result: AnyModel = vae_class.from_pretrained(
model_path, torch_dtype=self._torch_dtype, variant=variant
) # type: ignore
return result
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
if config.format != ModelFormat.Checkpoint:
return False
elif dest_path.exists() \
and (dest_path / "config.json").stat().st_mtime >= config.last_modified \
and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime:
return False
else:
return True
def _convert_model(self,
config: AnyModelConfig,
weights_path: Path,
output_path: Path
) -> Path:
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
raise Exception(f"Vae conversion not supported for model type: {config.base}")
else:
config_file = 'v1-inference.yaml' if config.base == BaseModelType.StableDiffusion1 else "v2-inference-v.yaml"
if weights_path.suffix == ".safetensors":
checkpoint = safetensors.torch.load_file(weights_path, device="cpu")
else:
checkpoint = torch.load(weights_path, map_location="cpu")
dtype = torch_dtype()
# sometimes weights are hidden under "state_dict", and sometimes not
if "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"]
ckpt_config = OmegaConf.load(self._app_config.legacy_conf_path / config_file)
assert isinstance(ckpt_config, DictConfig)
vae_model = convert_ldm_vae_to_diffusers(
checkpoint=checkpoint,
vae_config=ckpt_config,
image_size=512,
)
vae_model.to(dtype) # set precision appropriately
vae_model.save_pretrained(output_path, safe_serialization=True, torch_dtype=dtype)
return output_path
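The stacked @AnyModelLoader.register(...) decorators above suggest that loaders are looked up by a (base, type, format) key. The following is an illustrative sketch of that registry pattern only, assuming a simple dict keyed on the tuple; it is not the actual AnyModelLoader implementation:

from typing import Callable, Dict, Tuple, Type

_REGISTRY: Dict[Tuple[str, str, str], Type] = {}

def register(base: str, type: str, format: str) -> Callable[[Type], Type]:
    # Hypothetical decorator that records a loader class under its (base, type, format) key.
    def decorator(cls: Type) -> Type:
        _REGISTRY[(base, type, format)] = cls
        return cls
    return decorator

@register(base="any", type="vae", format="diffusers")
class ExampleVaeLoader:
    pass

assert _REGISTRY[("any", "vae", "diffusers")] is ExampleVaeLoader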

View File

@ -48,6 +48,9 @@ def _calc_onnx_model_by_data(model: IAIOnnxRuntimeModel) -> int:
def calc_model_size_by_fs(model_path: Path, subfolder: Optional[str] = None, variant: Optional[str] = None) -> int:
    """Estimate the size of a model on disk in bytes."""
    if model_path.is_file():
        return model_path.stat().st_size
    if subfolder is not None:
        model_path = model_path / subfolder

View File

@ -1,31 +0,0 @@
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
"""Class for VAE model loading in InvokeAI."""
from pathlib import Path
from typing import Dict, Optional
import torch
from invokeai.backend.model_manager import BaseModelType, ModelFormat, ModelRepoVariant, ModelType, SubModelType
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
from invokeai.backend.model_manager.load.load_default import ModelLoader
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers)
class VaeDiffusersModel(ModelLoader):
"""Class to load VAE models."""
def _load_model(
self,
model_path: Path,
model_variant: Optional[ModelRepoVariant] = None,
submodel_type: Optional[SubModelType] = None,
) -> Dict[str, torch.Tensor]:
if submodel_type is not None:
raise Exception("There are no submodels in VAEs")
vae_class = self._get_hf_load_class(model_path)
variant = model_variant.value if model_variant else ""
result: Dict[str, torch.Tensor] = vae_class.from_pretrained(
model_path, torch_dtype=self._torch_dtype, variant=variant
) # type: ignore
return result

View File

@ -12,6 +12,14 @@ from .devices import ( # noqa: F401
    torch_dtype,
)
from .logging import InvokeAILogger
from .util import Chdir, ask_user, download_with_resume, instantiate_from_config, url_attachment_name  # noqa: F401
from .util import (  # TO DO: Clean this up; remove the unused symbols
    GIG,
    Chdir,
    ask_user,  # noqa
    directory_size,
    download_with_resume,
    instantiate_from_config,  # noqa
    url_attachment_name,  # noqa
)
__all__ = ["Chdir", "InvokeAILogger", "choose_precision", "choose_torch_device"]
__all__ = ["GIG", "directory_size", "Chdir", "download_with_resume", "InvokeAILogger", "choose_precision", "choose_torch_device"]

View File

@ -1,7 +1,7 @@
from __future__ import annotations
from contextlib import nullcontext
from typing import Union
from typing import Union, Optional
import torch
from torch import autocast
@ -43,7 +43,8 @@ def choose_precision(device: torch.device) -> str:
return "float32" return "float32"
def torch_dtype(device: torch.device) -> torch.dtype: def torch_dtype(device: Optional[torch.device] = None) -> torch.dtype:
device = device or choose_torch_device()
precision = choose_precision(device) precision = choose_precision(device)
if precision == "float16": if precision == "float16":
return torch.float16 return torch.float16
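With device now optional, callers such as the VAE converter earlier in this commit can simply call torch_dtype() and let choose_torch_device() pick the device. A brief sketch of both call styles:

import torch
from invokeai.backend.util.devices import torch_dtype

dtype_default = torch_dtype()                 # device inferred via choose_torch_device()
dtype_cpu = torch_dtype(torch.device("cpu"))  # passing an explicit device still works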

View File

@ -24,6 +24,20 @@ import invokeai.backend.util.logging as logger
from .devices import torch_dtype
# Number of bytes in a gigabyte (2**30)
GIG = 1073741824
def directory_size(directory: Path) -> int:
"""
Return the aggregate size of all files in a directory (bytes).
"""
sum = 0
for root, dirs, files in os.walk(directory):
for f in files:
sum += Path(root, f).stat().st_size
for d in dirs:
sum += Path(root, d).stat().st_size
return sum
def log_txt_as_img(wh, xc, size=10):
    # wh a tuple of (width, height)
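The GIG constant and directory_size() helper added above are used for cache-size accounting elsewhere in this commit. A small hypothetical usage sketch; the cache path below is a placeholder, not a path defined by this commit:

from pathlib import Path
from invokeai.backend.util import GIG, directory_size

cache_dir = Path("models/.convert_cache")  # placeholder path for illustration
if cache_dir.exists():
    print(f"convert cache occupies {directory_size(cache_dir) / GIG:.2f} GB")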