model loading and conversion implemented for vaes

This commit is contained in:
Lincoln Stein 2024-02-03 22:55:09 -05:00 committed by Brandon Rising
parent 231c12fd1e
commit e242fe41f2
29 changed files with 2382 additions and 237 deletions

View File

@ -8,6 +8,8 @@ from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMe
from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk
from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.backend.model_manager.load import AnyModelLoader, ModelConvertCache
from invokeai.backend.model_manager.load.model_cache import ModelCache
from invokeai.backend.model_manager.metadata import ModelMetadataStore
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
from invokeai.backend.util.logging import InvokeAILogger
@ -98,15 +100,26 @@ class ApiDependencies:
)
model_manager = ModelManagerService(config, logger)
model_record_service = ModelRecordServiceSQL(db=db)
model_loader = AnyModelLoader(
app_config=config,
logger=logger,
ram_cache=ModelCache(
max_cache_size=config.ram_cache_size, max_vram_cache_size=config.vram_cache_size, logger=logger
),
convert_cache=ModelConvertCache(
cache_path=config.models_convert_cache_path, max_size=config.convert_cache_size
),
)
model_record_service = ModelRecordServiceSQL(db=db, loader=model_loader)
download_queue_service = DownloadQueueService(event_bus=events)
metadata_store = ModelMetadataStore(db=db)
model_install_service = ModelInstallService(
app_config=config,
record_store=model_record_service,
download_queue=download_queue_service,
metadata_store=metadata_store,
metadata_store=ModelMetadataStore(db=db),
event_bus=events,
)
model_manager = ModelManagerService(config, logger) # TO DO: legacy model manager v1. Remove
names = SimpleNameService()
performance_statistics = InvocationStatsService()
processor = DefaultInvocationProcessor()

View File

@ -237,6 +237,7 @@ class InvokeAIAppConfig(InvokeAISettings):
autoimport_dir : Path = Field(default=Path('autoimport'), description='Path to a directory of models files to be imported on startup.', json_schema_extra=Categories.Paths)
conf_path : Path = Field(default=Path('configs/models.yaml'), description='Path to models definition file', json_schema_extra=Categories.Paths)
models_dir : Path = Field(default=Path('models'), description='Path to the models directory', json_schema_extra=Categories.Paths)
convert_cache_dir : Path = Field(default=Path('models/.cache'), description='Path to the converted models cache directory', json_schema_extra=Categories.Paths)
legacy_conf_dir : Path = Field(default=Path('configs/stable-diffusion'), description='Path to directory of legacy checkpoint config files', json_schema_extra=Categories.Paths)
db_dir : Path = Field(default=Path('databases'), description='Path to InvokeAI databases directory', json_schema_extra=Categories.Paths)
outdir : Path = Field(default=Path('outputs'), description='Default folder for output images', json_schema_extra=Categories.Paths)
@ -262,6 +263,8 @@ class InvokeAIAppConfig(InvokeAISettings):
# CACHE
ram : float = Field(default=7.5, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number, GB)", json_schema_extra=Categories.ModelCache, )
vram : float = Field(default=0.25, ge=0, description="Amount of VRAM reserved for model storage (floating point number, GB)", json_schema_extra=Categories.ModelCache, )
convert_cache : float = Field(default=10.0, ge=0, description="Maximum size of on-disk converted models cache (GB)", json_schema_extra=Categories.ModelCache)
lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", json_schema_extra=Categories.ModelCache, )
log_memory_usage : bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.", json_schema_extra=Categories.ModelCache)
@ -404,6 +407,11 @@ class InvokeAIAppConfig(InvokeAISettings):
"""Path to the models directory."""
return self._resolve(self.models_dir)
@property
def models_convert_cache_path(self) -> Path:
    """Return the converted-models cache directory, passed through `_resolve`."""
    cache_dir = self.convert_cache_dir
    return self._resolve(cache_dir)
@property
def custom_nodes_path(self) -> Path:
"""Path to the custom nodes directory."""
@ -433,15 +441,20 @@ class InvokeAIAppConfig(InvokeAISettings):
return True
@property
def ram_cache_size(self) -> Union[Literal["auto"], float]:
"""Return the ram cache size using the legacy or modern setting."""
def ram_cache_size(self) -> float:
"""Return the ram cache size using the legacy or modern setting (GB)."""
return self.max_cache_size or self.ram
@property
def vram_cache_size(self) -> Union[Literal["auto"], float]:
"""Return the vram cache size using the legacy or modern setting."""
def vram_cache_size(self) -> float:
"""Return the vram cache size using the legacy or modern setting (GB)."""
return self.max_vram_cache_size or self.vram
@property
def convert_cache_size(self) -> float:
    """Maximum on-disk size of the converted models cache (GB)."""
    size_gb = self.convert_cache
    return size_gb
@property
def use_cpu(self) -> bool:
"""Return true if the device is set to CPU or the always_use_cpu flag is set."""

View File

@ -145,7 +145,7 @@ class ModelInstallService(ModelInstallServiceBase):
) -> str: # noqa D102
model_path = Path(model_path)
config = config or {}
if config.get("source") is None:
if not config.get("source"):
config["source"] = model_path.resolve().as_posix()
return self._register(model_path, config)
@ -156,7 +156,7 @@ class ModelInstallService(ModelInstallServiceBase):
) -> str: # noqa D102
model_path = Path(model_path)
config = config or {}
if config.get("source") is None:
if not config.get("source"):
config["source"] = model_path.resolve().as_posix()
info: AnyModelConfig = self._probe_model(Path(model_path), config)
@ -300,6 +300,7 @@ class ModelInstallService(ModelInstallServiceBase):
job.total_bytes = self._stat_size(job.local_path)
job.bytes = job.total_bytes
self._signal_job_running(job)
job.config_in["source"] = str(job.source)
if job.inplace:
key = self.register_path(job.local_path, job.config_in)
else:

View File

@ -11,7 +11,7 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union
from pydantic import BaseModel, Field
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType
from invokeai.backend.model_manager import LoadedModel, AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore
@ -102,6 +102,19 @@ class ModelRecordServiceBase(ABC):
"""
pass
@abstractmethod
def load_model(self, key: str, submodel_type: Optional[SubModelType]) -> LoadedModel:
    """
    Load the model identified by `key` into memory.

    :param key: Key of the model config to be fetched.
    :param submodel_type: For main (pipeline) models, the submodel to fetch.
    :return: A LoadedModel object for the requested model.

    Exceptions: UnknownModelException -- model with this key not known
                NotImplementedError -- a model loader was not provided at initialization time
    """
    ...
@property
@abstractmethod
def metadata_store(self) -> ModelMetadataStore:

View File

@ -42,6 +42,7 @@ Typical usage:
import json
import sqlite3
import time
from math import ceil
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, Union
@ -53,8 +54,10 @@ from invokeai.backend.model_manager.config import (
ModelConfigFactory,
ModelFormat,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore, UnknownMetadataException
from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel
from ..shared.sqlite.sqlite_database import SqliteDatabase
from .model_records_base import (
@ -69,16 +72,17 @@ from .model_records_base import (
class ModelRecordServiceSQL(ModelRecordServiceBase):
"""Implementation of the ModelConfigStore ABC using a SQL database."""
def __init__(self, db: SqliteDatabase):
def __init__(self, db: SqliteDatabase, loader: Optional[AnyModelLoader]=None):
"""
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
:param conn: sqlite3 connection object
:param lock: threading Lock object
:param db: Sqlite connection object
:param loader: Initialized model loader object (optional)
"""
super().__init__()
self._db = db
self._cursor = self._db.conn.cursor()
self._cursor = db.conn.cursor()
self._loader = loader
@property
def db(self) -> SqliteDatabase:
@ -199,7 +203,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT config FROM model_config
SELECT config, strftime('%s',updated_at) FROM model_config
WHERE id=?;
""",
(key,),
@ -207,9 +211,24 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
rows = self._cursor.fetchone()
if not rows:
raise UnknownModelException("model not found")
model = ModelConfigFactory.make_config(json.loads(rows[0]))
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
return model
def load_model(self, key: str, submodel_type: Optional[SubModelType]) -> LoadedModel:
    """
    Look up the config stored under `key` and hand it to the model loader.

    :param key: Key of the model config to be fetched.
    :param submodel_type: For main (pipeline) models, the submodel to fetch.
    :return: The LoadedModel object produced by the loader.

    Exceptions: UnknownModelException -- model with this key not known
                NotImplementedError -- a model loader was not provided at initialization time
    """
    loader = self._loader
    if loader:
        # The loader is checked first, so a missing loader is reported
        # before any database lookup is attempted.
        return loader.load_model(self.get_model(key), submodel_type)
    raise NotImplementedError(f"Class {self.__class__} was not initialized with a model loader")
def exists(self, key: str) -> bool:
"""
Return True if a model with the indicated key exists in the databse.
@ -265,12 +284,12 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
with self._db.lock:
self._cursor.execute(
f"""--sql
select config FROM model_config
select config, strftime('%s',updated_at) FROM model_config
{where};
""",
tuple(bindings),
)
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()]
return results
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
@ -279,12 +298,12 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT config FROM model_config
SELECT config, strftime('%s',updated_at) FROM model_config
WHERE path=?;
""",
(str(path),),
)
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()]
return results
def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
@ -293,12 +312,12 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT config FROM model_config
SELECT config, strftime('%s',updated_at) FROM model_config
WHERE original_hash=?;
""",
(hash,),
)
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()]
return results
@property

View File

@ -8,6 +8,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_3 import build_migration_3
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_4 import build_migration_4
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_5 import build_migration_5
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_6 import build_migration_6
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
@ -33,6 +34,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator.register_migration(build_migration_3(app_config=config, logger=logger))
migrator.register_migration(build_migration_4())
migrator.register_migration(build_migration_5())
migrator.register_migration(build_migration_6())
migrator.run_migrations()
return db

View File

@ -0,0 +1,44 @@
import sqlite3
from logging import Logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration6Callback:
    """Callable migration step that restores the `model_config_updated_at` trigger."""

    def __call__(self, cursor: sqlite3.Cursor) -> None:
        """Run the migration against the database behind `cursor`."""
        self._recreate_model_triggers(cursor)

    def _recreate_model_triggers(self, cursor: sqlite3.Cursor) -> None:
        """
        Adds the timestamp trigger to the model_config table.

        This trigger was inadvertently dropped in earlier migration scripts.
        `IF NOT EXISTS` makes the statement safe to run when the trigger is
        already present.
        """
        create_trigger_sql = """--sql
            CREATE TRIGGER IF NOT EXISTS model_config_updated_at
            AFTER UPDATE
            ON model_config FOR EACH ROW
            BEGIN
            UPDATE model_config SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
            WHERE id = old.id;
            END;
            """
        cursor.execute(create_trigger_sql)
def build_migration_6() -> Migration:
    """
    Build the migration from database version 5 to 6.

    This migration does the following:
    - Adds the model_config_updated_at trigger if it does not exist
    """
    return Migration(
        from_version=5,
        to_version=6,
        callback=Migration6Callback(),
    )

View File

@ -98,11 +98,13 @@ class TqdmEventService(EventServiceBase):
super().__init__()
self._bars: Dict[str, tqdm] = {}
self._last: Dict[str, int] = {}
self._logger = InvokeAILogger.get_logger(__name__)
def dispatch(self, event_name: str, payload: Any) -> None:
"""Dispatch an event by appending it to self.events."""
data = payload["data"]
source = data["source"]
if payload["event"] == "model_install_downloading":
data = payload["data"]
dest = data["local_path"]
total_bytes = data["total_bytes"]
bytes = data["bytes"]
@ -111,7 +113,12 @@ class TqdmEventService(EventServiceBase):
self._last[dest] = 0
self._bars[dest].update(bytes - self._last[dest])
self._last[dest] = bytes
elif payload["event"] == "model_install_completed":
self._logger.info(f"{source}: installed successfully.")
elif payload["event"] == "model_install_error":
self._logger.warning(f"{source}: installation failed with error {data['error']}")
elif payload["event"] == "model_install_cancelled":
self._logger.warning(f"{source}: installation cancelled")
class InstallHelper(object):
"""Capture information stored jointly in INITIAL_MODELS.yaml and the installed models db."""

View File

@ -1,6 +1,7 @@
"""Re-export frequently-used symbols from the Model Manager backend."""
from .config import (
AnyModel,
AnyModelConfig,
BaseModelType,
InvalidModelConfigException,
@ -14,12 +15,15 @@ from .config import (
)
from .probe import ModelProbe
from .search import ModelSearch
from .load import LoadedModel
__all__ = [
"AnyModel",
"AnyModelConfig",
"BaseModelType",
"ModelRepoVariant",
"InvalidModelConfigException",
"LoadedModel",
"ModelConfigFactory",
"ModelFormat",
"ModelProbe",

View File

@ -19,12 +19,15 @@ Typical usage:
Validation errors will raise an InvalidModelConfigException error.
"""
import time
import torch
from enum import Enum
from typing import Literal, Optional, Type, Union
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
from diffusers import ModelMixin
from typing_extensions import Annotated, Any, Dict
from .onnx_runtime import IAIOnnxRuntimeModel
class InvalidModelConfigException(Exception):
"""Exception for when config parser doesn't recognized this combination of model type and format."""
@ -127,6 +130,7 @@ class ModelConfigBase(BaseModel):
) # if model is converted or otherwise modified, this will hold updated hash
description: Optional[str] = Field(default=None)
source: Optional[str] = Field(description="Model download source (URL or repo_id)", default=None)
last_modified: Optional[float] = Field(description="Timestamp for modification time", default_factory=time.time)
model_config = ConfigDict(
use_enum_values=False,
@ -280,6 +284,7 @@ AnyModelConfig = Union[
]
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
AnyModel = Union[ModelMixin, torch.nn.Module, IAIOnnxRuntimeModel]
# IMPLEMENTATION NOTE:
# The preferred alternative to the above is a discriminated Union as shown
@ -312,6 +317,7 @@ class ModelConfigFactory(object):
model_data: Union[dict, AnyModelConfig],
key: Optional[str] = None,
dest_class: Optional[Type] = None,
timestamp: Optional[float] = None
) -> AnyModelConfig:
"""
Return the appropriate config object from raw dict values.
@ -330,4 +336,6 @@ class ModelConfigFactory(object):
model = AnyModelConfigValidator.validate_python(model_data)
if key:
model.key = key
if timestamp:
model.last_modified = timestamp
return model

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,35 @@
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development Team
"""
Init file for the model loader.
"""
from importlib import import_module
from pathlib import Path
from typing import Optional
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger
from .load_base import AnyModelLoader, LoadedModel
from .model_cache.model_cache_default import ModelCache
from .convert_cache.convert_cache_default import ModelConvertCache
# This registers the subclasses that implement loaders of specific model types.
# Every non-__init__ module in the sibling `model_loaders` directory is imported
# purely for its import-time side effects; nothing is written to stdout here so
# that importing this package stays silent.
loaders = [x.stem for x in Path(Path(__file__).parent, "model_loaders").glob("*.py") if x.stem != "__init__"]
for module in loaders:
    import_module(f"{__package__}.model_loaders.{module}")
__all__ = ["AnyModelLoader", "LoadedModel"]
def get_standalone_loader(app_config: Optional[InvokeAIAppConfig] = None) -> AnyModelLoader:
    """
    Build an AnyModelLoader with freshly created RAM and convert caches.

    :param app_config: Application configuration; when omitted or None, the
        result of `InvokeAIAppConfig.get_config()` is used.
    :return: A loader whose caches are sized from `app_config`.
    """
    app_config = app_config or InvokeAIAppConfig.get_config()
    logger = InvokeAILogger.get_logger(config=app_config)
    return AnyModelLoader(
        app_config=app_config,
        logger=logger,
        ram_cache=ModelCache(
            logger=logger,
            max_cache_size=app_config.ram_cache_size,
            max_vram_cache_size=app_config.vram_cache_size,
        ),
        # Honour the configured on-disk limit instead of silently falling back
        # to ModelConvertCache's built-in default.
        convert_cache=ModelConvertCache(
            cache_path=app_config.models_convert_cache_path,
            max_size=app_config.convert_cache_size,
        ),
    )

View File

@ -0,0 +1,4 @@
from .convert_cache_base import ModelConvertCacheBase
from .convert_cache_default import ModelConvertCache
__all__ = ['ModelConvertCacheBase', 'ModelConvertCache']

View File

@ -0,0 +1,28 @@
"""
Disk-based converted model cache.
"""
from abc import ABC, abstractmethod
from pathlib import Path
class ModelConvertCacheBase(ABC):
    """Abstract interface for the on-disk cache of converted models."""

    @property
    @abstractmethod
    def max_size(self) -> float:
        """Return the maximum size of this cache directory (GB)."""
        pass

    @abstractmethod
    def make_room(self, size: float) -> None:
        """
        Make sufficient room in the cache directory for a model of the given size.

        :param size: Size required. NOTE(review): ModelConvertCache adds this
            value to a directory byte count, so it is treated as bytes rather
            than GB -- confirm against callers.
        """
        pass

    @abstractmethod
    def cache_path(self, key: str) -> Path:
        """Return the path for a model with the indicated key."""
        pass

View File

@ -0,0 +1,64 @@
"""
Placeholder for convert cache implementation.
"""
from pathlib import Path
import shutil
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.backend.util import GIG, directory_size
from .convert_cache_base import ModelConvertCacheBase
class ModelConvertCache(ModelConvertCacheBase):
    """Directory-backed cache of converted models that evicts the least recently accessed entries."""

    def __init__(self, cache_path: Path, max_size: float = 10.0):
        """Initialize the convert cache with the base directory and a limit on its maximum size (in GBs)."""
        # exist_ok closes the window between an existence check and mkdir in
        # which another process could create the directory.
        cache_path.mkdir(parents=True, exist_ok=True)
        self._cache_path = cache_path
        self._max_size = max_size

    @property
    def max_size(self) -> float:
        """Return the maximum size of this cache directory (GB)."""
        return self._max_size

    def cache_path(self, key: str) -> Path:
        """Return the path for a model with the indicated key."""
        return self._cache_path / key

    def make_room(self, size: float) -> None:
        """
        Make sufficient room in the cache directory for a model of the given size.

        Entries are removed in order of oldest access time until the current
        directory size plus `size` fits within the configured maximum, or
        until no entries remain.

        :param size: Size required, in the same units that `directory_size`
            returns (compared against `max_size * GIG`, i.e. bytes).
        """
        size_needed = directory_size(self._cache_path) + size
        # Scale to bytes *before* truncating; int(max_size) * GIG would turn a
        # fractional limit such as 0.5 GB into 0 and evict everything.
        max_size = int(self.max_size * GIG)
        logger = InvokeAILogger.get_logger()

        if size_needed <= max_size:
            return

        logger.debug(
            f"Convert cache has gotten too large {(size_needed / GIG):4.2f} > {(max_size / GIG):4.2f}G.. Trimming."
        )

        # For this to work, we make the assumption that the directory contains
        # a 'model_index.json', 'unet/config.json' file, or a 'config.json' file at top level.
        # This should be true for any diffusers model.
        def by_atime(path: Path) -> float:
            for config in ["model_index.json", "unet/config.json", "config.json"]:
                sentinel = path / config
                if sentinel.exists():
                    return sentinel.stat().st_atime
            # Entries without a recognizable sentinel sort as the oldest.
            return 0.0

        # Sort by last access time, most recent first, so that pop() below
        # yields the least recently accessed entry.
        lru_models = sorted(self._cache_path.iterdir(), key=by_atime, reverse=True)
        logger.debug(f"cached models in descending atime order: {lru_models}")
        while size_needed > max_size and len(lru_models) > 0:
            next_victim = lru_models.pop()
            victim_size = directory_size(next_victim)
            logger.debug(f"Removing cached converted model {next_victim} to free {victim_size / GIG} GB")
            shutil.rmtree(next_victim)
            size_needed -= victim_size

View File

@ -16,39 +16,11 @@ from logging import Logger
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Type, Union
import torch
from diffusers import DiffusionPipeline
from injector import inject
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.model_manager.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.onnx_runtime import IAIOnnxRuntimeModel
from invokeai.backend.model_manager.ram_cache import ModelCacheBase
AnyModel = Union[DiffusionPipeline, torch.nn.Module, IAIOnnxRuntimeModel]
class ModelLockerBase(ABC):
"""Base class for the model locker used by the loader."""
@abstractmethod
def lock(self) -> None:
"""Lock the contained model and move it into VRAM."""
pass
@abstractmethod
def unlock(self) -> None:
"""Unlock the contained model, and remove it from VRAM."""
pass
@property
@abstractmethod
def model(self) -> AnyModel:
"""Return the model."""
pass
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
from invokeai.backend.model_manager.load.model_cache.model_locker import ModelLockerBase
from invokeai.backend.model_manager.load.convert_cache.convert_cache_base import ModelConvertCacheBase
@dataclass
class LoadedModel:
@ -69,7 +41,7 @@ class LoadedModel:
@property
def model(self) -> AnyModel:
"""Return the model without locking it."""
return self.locker.model()
return self.locker.model
class ModelLoaderBase(ABC):
@ -89,9 +61,9 @@ class ModelLoaderBase(ABC):
@abstractmethod
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""
Return a model given its key.
Return a model given its confguration.
Given a model key identified in the model configuration backend,
Given a model identified in the model configuration backend,
return a ModelInfo object that can be used to retrieve the model.
:param model_config: Model configuration, as returned by ModelConfigRecordStore
@ -115,34 +87,32 @@ class AnyModelLoader:
# this tracks the loader subclasses
_registry: Dict[str, Type[ModelLoaderBase]] = {}
@inject
def __init__(
self,
store: ModelRecordServiceBase,
app_config: InvokeAIAppConfig,
logger: Logger,
ram_cache: ModelCacheBase,
convert_cache: ModelConvertCacheBase,
):
"""Store the provided ModelRecordServiceBase and empty the registry."""
self._store = store
"""Initialize AnyModelLoader with its dependencies."""
self._app_config = app_config
self._logger = logger
self._ram_cache = ram_cache
self._convert_cache = convert_cache
def get_model(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""
Return a model given its key.
@property
def ram_cache(self) -> ModelCacheBase:
    """Return the RAM cache that this loader was constructed with."""
    cache = self._ram_cache
    return cache
Given a model key identified in the model configuration backend,
return a ModelInfo object that can be used to retrieve the model.
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType]=None) -> LoadedModel:
"""
Return a model given its configuration.
:param key: model key, as known to the config backend
:param submodel_type: an ModelType enum indicating the portion of
the model to retrieve (e.g. ModelType.Vae)
"""
model_config = self._store.get_model(key)
implementation = self.__class__.get_implementation(
base=model_config.base, type=model_config.type, format=model_config.format
)
@ -165,7 +135,7 @@ class AnyModelLoader:
implementation = cls._registry.get(key1) or cls._registry.get(key2)
if not implementation:
raise NotImplementedError(
"No subclass of LoadedModel is registered for base={base}, type={type}, format={format}"
f"No subclass of LoadedModel is registered for base={base}, type={type}, format={format}"
)
return implementation
@ -176,18 +146,10 @@ class AnyModelLoader:
"""Define a decorator which registers the subclass of loader."""
def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]:
print("Registering class", subclass.__name__)
print("DEBUG: Registering class", subclass.__name__)
key = cls._to_registry_key(base, type, format)
cls._registry[key] = subclass
return subclass
return decorator
# in _init__.py will call something like
# def configure_loader_dependencies(binder):
# binder.bind(ModelRecordServiceBase, ApiDependencies.invoker.services.model_records, scope=singleton)
# binder.bind(InvokeAIAppConfig, ApiDependencies.invoker.services.configuration, scope=singleton)
# etc
# injector = Injector(configure_loader_dependencies)
# loader = injector.get(ModelFactory)

View File

@ -8,15 +8,14 @@ from typing import Any, Dict, Optional, Tuple
from diffusers import ModelMixin
from diffusers.configuration_utils import ConfigMixin
from injector import inject
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager import AnyModelConfig, InvalidModelConfigException, ModelRepoVariant, SubModelType
from invokeai.backend.model_manager.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.load_base import AnyModel, LoadedModel, ModelLoaderBase
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.model_manager.ram_cache import ModelCacheBase, ModelLockerBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
@ -35,7 +34,6 @@ class ConfigLoader(ConfigMixin):
class ModelLoader(ModelLoaderBase):
"""Default implementation of ModelLoaderBase."""
@inject # can inject instances of each of the classes in the call signature
def __init__(
self,
app_config: InvokeAIAppConfig,
@ -87,18 +85,15 @@ class ModelLoader(ModelLoaderBase):
def _convert_if_needed(
self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
) -> Path:
if not self._needs_conversion(config):
return model_path
cache_path: Path = self._convert_cache.cache_path(config.key)
if not self._needs_conversion(config, model_path, cache_path):
return cache_path if cache_path.exists() else model_path
self._convert_cache.make_room(self._size or self.get_size_fs(config, model_path, submodel_type))
cache_path: Path = self._convert_cache.cache_path(config.key)
if cache_path.exists():
return cache_path
return self._convert_model(config, model_path, cache_path)
self._convert_model(model_path, cache_path)
return cache_path
def _needs_conversion(self, config: AnyModelConfig) -> bool:
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, cache_path: Path) -> bool:
return False
def _load_if_needed(
@ -133,7 +128,7 @@ class ModelLoader(ModelLoaderBase):
variant=config.repo_variant if hasattr(config, "repo_variant") else None,
)
def _convert_model(self, model_path: Path, cache_path: Path) -> None:
def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
raise NotImplementedError
def _load_model(

View File

@ -0,0 +1,5 @@
"""Init file for RamCache."""
from .model_cache_base import ModelCacheBase
from .model_cache_default import ModelCache
# Public API of the model cache package (was misspelled `_all__`, which Python
# treats as an ordinary variable rather than the export list).
__all__ = ["ModelCacheBase", "ModelCache"]

View File

@ -10,34 +10,41 @@ model will be cleared and (re)loaded from disk when next needed.
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from logging import Logger
from typing import Dict, Optional
from typing import Dict, Optional, TypeVar, Generic
import torch
from invokeai.backend.model_manager import SubModelType
from invokeai.backend.model_manager.load.load_base import AnyModel, ModelLockerBase
from invokeai.backend.model_manager import AnyModel, SubModelType
class ModelLockerBase(ABC):
    """Abstract interface for the model locker used by the loader."""

    @abstractmethod
    def lock(self) -> AnyModel:
        """Lock the contained model and move it into VRAM."""
        ...

    @abstractmethod
    def unlock(self) -> None:
        """Unlock the contained model, and remove it from VRAM."""
        ...

    @property
    @abstractmethod
    def model(self) -> AnyModel:
        """Return the model."""
        ...
T = TypeVar("T")
@dataclass
class CacheStats(object):
"""Data object to record statistics on cache hits/misses."""
hits: int = 0 # cache hits
misses: int = 0 # cache misses
high_watermark: int = 0 # amount of cache used
in_cache: int = 0 # number of models in cache
cleared: int = 0 # number of models cleared to make space
cache_size: int = 0 # total size of cache
loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
@dataclass
class CacheRecord:
class CacheRecord(Generic[T]):
"""Elements of the cache."""
key: str
model: AnyModel
model: T
size: int
loaded: bool = False
_locks: int = 0
def lock(self) -> None:
@ -55,7 +62,7 @@ class CacheRecord:
return self._locks > 0
class ModelCacheBase(ABC):
class ModelCacheBase(ABC, Generic[T]):
"""Virtual base class for RAM model cache."""
@property
@ -76,8 +83,14 @@ class ModelCacheBase(ABC):
"""Return true if the cache is configured to lazily offload models in VRAM."""
pass
@property
@abstractmethod
def offload_unlocked_models(self) -> None:
def max_cache_size(self) -> float:
    """
    Return the maximum size of the cache.

    NOTE(review): presumably in GB, matching the `max_cache_size` constructor
    argument that is fed from the config's RAM cache size -- confirm.
    """
    pass
@abstractmethod
def offload_unlocked_models(self, size_required: int) -> None:
"""Offload from VRAM any models not actively in use."""
pass
@ -101,7 +114,7 @@ class ModelCacheBase(ABC):
def put(
self,
key: str,
model: AnyModel,
model: T,
submodel_type: Optional[SubModelType] = None,
) -> None:
"""Store model under key and optional submodel_type."""
@ -134,11 +147,6 @@ class ModelCacheBase(ABC):
"""Get the total size of the models currently cached."""
pass
@abstractmethod
def get_stats(self) -> CacheStats:
"""Return cache hit/miss/size statistics."""
pass
@abstractmethod
def print_cuda_stats(self) -> None:
"""Log debugging information on CUDA usage."""

View File

@ -18,6 +18,7 @@ context. Use like this:
"""
import gc
import math
import time
from contextlib import suppress
@ -26,14 +27,14 @@ from typing import Any, Dict, List, Optional
import torch
from invokeai.app.services.model_records import UnknownModelException
from invokeai.backend.model_manager import SubModelType
from invokeai.backend.model_manager.load.load_base import AnyModel, ModelLockerBase
from invokeai.backend.model_manager.load.load_base import AnyModel
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
from invokeai.backend.model_manager.load.ram_cache.ram_cache_base import CacheRecord, CacheStats, ModelCacheBase
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.logging import InvokeAILogger
from .model_cache_base import CacheRecord, ModelCacheBase
from .model_locker import ModelLockerBase, ModelLocker
if choose_torch_device() == torch.device("mps"):
from torch import mps
@ -52,7 +53,7 @@ GIG = 1073741824
MB = 2**20
class ModelCache(ModelCacheBase):
class ModelCache(ModelCacheBase[AnyModel]):
"""Implementation of ModelCacheBase."""
def __init__(
@ -92,62 +93,9 @@ class ModelCache(ModelCacheBase):
self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
self._log_memory_usage = log_memory_usage
# used for stats collection
self.stats = None
self._cached_models: Dict[str, CacheRecord] = {}
self._cached_models: Dict[str, CacheRecord[AnyModel]] = {}
self._cache_stack: List[str] = []
class ModelLocker(ModelLockerBase):
"""Internal class that mediates movement in and out of GPU."""
def __init__(self, cache: ModelCacheBase, cache_entry: CacheRecord):
"""
Initialize the model locker.
:param cache: The ModelCache object
:param cache_entry: The entry in the model cache
"""
self._cache = cache
self._cache_entry = cache_entry
@property
def model(self) -> AnyModel:
"""Return the model without moving it around."""
return self._cache_entry.model
def lock(self) -> Any:
"""Move the model into the execution device (GPU) and lock it."""
# Objects without a `to` attribute are returned as-is and are never locked.
if not hasattr(self.model, "to"):
return self.model
# NOTE that the model has to have the to() method in order for this code to move it into GPU!
self._cache_entry.lock()
try:
# With lazy offloading, unlocked models are offloaded here, just before this model is moved.
if self._cache.lazy_offloading:
self._cache.offload_unlocked_models()
self._cache.move_model_to_device(self._cache_entry, self._cache.execution_device)
self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}")
self._cache.print_cuda_stats()
except Exception:
# Release the lock taken above before re-raising the error.
self._cache_entry.unlock()
raise
return self.model
def unlock(self) -> None:
"""Call upon exit from context."""
# Mirrors lock(): objects without a `to` attribute were never locked.
if not hasattr(self.model, "to"):
return
self._cache_entry.unlock()
# Without lazy offloading, unlocked models are offloaded as soon as the lock is released.
if not self._cache.lazy_offloading:
self._cache.offload_unlocked_models()
self._cache.print_cuda_stats()
@property
def logger(self) -> Logger:
"""Return the logger used by the cache."""
@ -168,6 +116,11 @@ class ModelCache(ModelCacheBase):
"""Return the exection device (e.g. "cuda" for VRAM)."""
return self._execution_device
@property
def max_cache_size(self) -> float:
"""Return the cap on cache size, in gigabytes."""
return self._max_cache_size
def cache_size(self) -> int:
"""Get the total size of the models currently cached."""
total = 0
@ -207,18 +160,18 @@ class ModelCache(ModelCacheBase):
"""
Retrieve model using key and optional submodel_type.
This may return an UnknownModelException if the model is not in the cache.
This will raise an IndexError if the model is not in the cache.
"""
key = self._make_cache_key(key, submodel_type)
if key not in self._cached_models:
raise UnknownModelException
raise IndexError(f"The model with key {key} is not in the cache.")
# this moves the entry to the top (right end) of the stack
with suppress(Exception):
self._cache_stack.remove(key)
self._cache_stack.append(key)
cache_entry = self._cached_models[key]
return self.ModelLocker(
return ModelLocker(
cache=self,
cache_entry=cache_entry,
)
@ -234,19 +187,19 @@ class ModelCache(ModelCacheBase):
else:
return model_key
def offload_unlocked_models(self) -> None:
def offload_unlocked_models(self, size_required: int) -> None:
"""Move any unused models from VRAM."""
reserved = self._max_vram_cache_size * GIG
vram_in_use = torch.cuda.memory_allocated()
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB")
vram_in_use = torch.cuda.memory_allocated() + size_required
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB")
for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
if vram_in_use <= reserved:
break
if not cache_entry.locked:
self.move_model_to_device(cache_entry, self.storage_device)
vram_in_use = torch.cuda.memory_allocated()
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB")
cache_entry.loaded = False
vram_in_use = torch.cuda.memory_allocated() + size_required
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM now available for models; max allowed={(reserved/GIG):.2f}GB")
torch.cuda.empty_cache()
if choose_torch_device() == torch.device("mps"):
@ -305,28 +258,111 @@ class ModelCache(ModelCacheBase):
def print_cuda_stats(self) -> None:
"""Log CUDA diagnostics."""
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
ram = "%4.2fG" % self.cache_size()
ram = "%4.2fG" % (self.cache_size() / GIG)
cached_models = 0
loaded_models = 0
locked_models = 0
in_ram_models = 0
in_vram_models = 0
locked_in_vram_models = 0
for cache_record in self._cached_models.values():
cached_models += 1
assert hasattr(cache_record.model, "device")
if cache_record.model.device is self.storage_device:
loaded_models += 1
if cache_record.model.device == self.storage_device:
in_ram_models += 1
else:
in_vram_models += 1
if cache_record.locked:
locked_models += 1
locked_in_vram_models += 1
self.logger.debug(
f"Current VRAM/RAM usage: {vram}/{ram}; cached_models/loaded_models/locked_models/ ="
f" {cached_models}/{loaded_models}/{locked_models}"
f"Current VRAM/RAM usage: {vram}/{ram}; models_in_ram/models_in_vram(locked) ="
f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})"
)
def get_stats(self) -> CacheStats:
"""Return cache hit/miss/size statistics."""
raise NotImplementedError
def make_room(self, size: int) -> None:
def make_room(self, model_size: int) -> None:
    """
    Make enough room in the cache to accommodate a new model of indicated size.

    Entries are examined starting from the front of the cache stack. An entry is
    evicted only when it is not locked and its model has no unexpected references.
    Eviction stops once the cached total plus `model_size` fits under the maximum
    cache size, or when every entry has been examined.

    :param model_size: Size of the incoming model, in bytes.
    """
    # `sys` is needed for getrefcount(); imported here so this method does not
    # depend on a module-level import being present.
    import sys

    # calculate how much memory this model will require
    bytes_needed = model_size
    maximum_size = self.max_cache_size * GIG  # stored in GB, convert to bytes
    current_size = self.cache_size()

    if current_size + bytes_needed > maximum_size:
        self.logger.debug(
            f"Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional"
            f" {(bytes_needed/GIG):.2f} GB"
        )

    self.logger.debug(f"Before unloading: cached_models={len(self._cached_models)}")

    pos = 0
    models_cleared = 0
    while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
        model_key = self._cache_stack[pos]
        cache_entry = self._cached_models[model_key]

        refs = sys.getrefcount(cache_entry.model)

        # HACK: This is a workaround for a memory-management issue that we haven't tracked down yet. We are directly
        # going against the advice in the Python docs by using `gc.get_referrers(...)` in this way:
        # https://docs.python.org/3/library/gc.html#gc.get_referrers

        # Manually clear local variable references left behind by function calls that have just finished.
        # For some reason Python does not collect them immediately, even after gc.collect().
        # NOTE(review): `refs` is not recomputed after the frames are cleared, so the
        # eviction test below uses the count taken before this cleanup. Confirm whether
        # that is intended.
        if refs > 2:
            while True:
                cleared = False
                for referrer in gc.get_referrers(cache_entry.model):
                    if type(referrer).__name__ == "frame":
                        # RuntimeError: cannot clear an executing frame
                        with suppress(RuntimeError):
                            referrer.clear()
                            cleared = True

                # repeat if the referrers changed (due to a frame being cleared), else exit the loop
                if cleared:
                    gc.collect()
                else:
                    break

        device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
        self.logger.debug(
            f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded},"
            f" refs: {refs}"
        )

        # Expected refs:
        # 1 from cache_entry
        # 1 from getrefcount function
        # 1 from onnx runtime object
        if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2):
            self.logger.debug(
                f"Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
            )
            current_size -= cache_entry.size
            models_cleared += 1
            # The entry at `pos` is removed, so `pos` already points at the next candidate.
            del self._cache_stack[pos]
            del self._cached_models[model_key]
            del cache_entry
        else:
            pos += 1

    if models_cleared > 0:
        # There would likely be some 'garbage' to be collected regardless of whether a model was cleared or not, but
        # there is a significant time cost to calling `gc.collect()`, so we want to use it sparingly. (The time cost
        # is high even if no garbage gets collected.)
        #
        # Calling gc.collect(...) when a model is cleared seems like a good middle-ground:
        # - If models had to be cleared, it's a signal that we are close to our memory limit.
        # - If models were cleared, there's a good chance that there's a significant amount of garbage to be
        #   collected.
        #
        # Keep in mind that gc is only responsible for handling reference cycles. Most objects should be cleaned up
        # immediately when their reference count hits 0.
        gc.collect()

    torch.cuda.empty_cache()
    if choose_torch_device() == torch.device("mps"):
        mps.empty_cache()

    self.logger.debug(f"After unloading: cached_models={len(self._cached_models)}")

View File

@ -0,0 +1,59 @@
"""
Base class and implementation of a class that moves models in and out of VRAM.
"""
from abc import ABC, abstractmethod
from invokeai.backend.model_manager import AnyModel
from .model_cache_base import ModelLockerBase, ModelCacheBase, CacheRecord
class ModelLocker(ModelLockerBase):
    """Handle that locks a cache entry and moves its model to and from the execution device."""

    def __init__(self, cache: ModelCacheBase[AnyModel], cache_entry: CacheRecord[AnyModel]):
        """
        Set up the locker for one cache entry.

        :param cache: The ModelCache object
        :param cache_entry: The entry in the model cache
        """
        self._cache_entry = cache_entry
        self._cache = cache

    @property
    def model(self) -> AnyModel:
        """Return the cache entry's model; nothing is moved."""
        return self._cache_entry.model

    def lock(self) -> AnyModel:
        """Lock the cache entry, move its model to the execution device, and return the model."""
        # A model can only be moved if it has a to() method; anything else is returned untouched.
        if not hasattr(self.model, "to"):
            return self.model

        entry = self._cache_entry
        cache = self._cache

        entry.lock()
        try:
            if cache.lazy_offloading:
                cache.offload_unlocked_models(entry.size)

            cache.move_model_to_device(entry, cache.execution_device)
            entry.loaded = True

            cache.logger.debug(f"Locking {entry.key} in {cache.execution_device}")
            cache.print_cuda_stats()
        except Exception:
            # Give the lock back before propagating the failure.
            entry.unlock()
            raise

        return self.model

    def unlock(self) -> None:
        """Release the lock; without lazy offloading, also offload unlocked models right away."""
        if not hasattr(self.model, "to"):
            return

        entry = self._cache_entry
        cache = self._cache

        entry.unlock()
        if cache.lazy_offloading:
            return
        cache.offload_unlocked_models(entry.size)
        cache.print_cuda_stats()

View File

@ -0,0 +1,3 @@
"""
Init file for model_loaders.
"""

View File

@ -0,0 +1,83 @@
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
"""Class for VAE model loading in InvokeAI."""
from pathlib import Path
from typing import Optional
import torch
import safetensors
from omegaconf import OmegaConf, DictConfig
from invokeai.backend.util.devices import torch_dtype
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelFormat, ModelRepoVariant, ModelType, SubModelType
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers)
@AnyModelLoader.register(base=BaseModelType.StableDiffusion1, type=ModelType.Vae, format=ModelFormat.Checkpoint)
@AnyModelLoader.register(base=BaseModelType.StableDiffusion2, type=ModelType.Vae, format=ModelFormat.Checkpoint)
class VaeDiffusersModel(ModelLoader):
    """
    Class to load VAE models.

    Registered for diffusers-format VAEs of any base model, and for checkpoint-format
    VAEs based on Stable Diffusion 1 or 2.
    """

    def _load_model(
        self,
        model_path: Path,
        model_variant: Optional[ModelRepoVariant] = None,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        """
        Load a VAE from `model_path` with `from_pretrained`.

        :param model_path: Location of the model to load.
        :param model_variant: Optional repo variant; its value is passed on as `variant`.
        :param submodel_type: Must be None, since VAEs have no submodels.
        :raises Exception: If a submodel is requested.
        """
        if submodel_type is not None:
            raise Exception("There are no submodels in VAEs")
        vae_class = self._get_hf_load_class(model_path)
        variant = model_variant.value if model_variant else None
        result: AnyModel = vae_class.from_pretrained(
            model_path, torch_dtype=self._torch_dtype, variant=variant
        )  # type: ignore
        return result

    def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
        """
        Return True if a checkpoint model must be (re)converted into `dest_path`.

        Conversion is needed when the model is in checkpoint format and the converted
        `config.json` is missing, or is older than either the config's `last_modified`
        value or the modification time of `model_path`.
        """
        if config.format != ModelFormat.Checkpoint:
            return False

        # Check the marker file itself: stat() on a missing file raises FileNotFoundError,
        # which is what happens on the very first conversion.
        converted_config = dest_path / "config.json"
        if not converted_config.exists():
            return True

        converted_mtime = converted_config.stat().st_mtime
        is_current = converted_mtime >= config.last_modified and converted_mtime >= model_path.stat().st_mtime
        return not is_current

    def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
        """
        Convert a checkpoint-format VAE and save the result to `output_path`.

        :param config: Model configuration; `config.base` selects the legacy config file.
        :param weights_path: Checkpoint file (`.safetensors` or a torch-pickled file).
        :param output_path: Directory the converted model is saved to.
        :return: `output_path`.
        :raises Exception: If `config.base` is not Stable Diffusion 1 or 2.
        """
        # `import safetensors` alone does not load the `torch` submodule, so import it explicitly.
        import safetensors.torch

        if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
            raise Exception(f"Vae conversion not supported for model type: {config.base}")
        config_file = "v1-inference.yaml" if config.base == BaseModelType.StableDiffusion1 else "v2-inference-v.yaml"

        if weights_path.suffix == ".safetensors":
            checkpoint = safetensors.torch.load_file(weights_path, device="cpu")
        else:
            # NOTE: torch.load unpickles the file and can execute arbitrary code; only use
            # checkpoint files from trusted sources.
            checkpoint = torch.load(weights_path, map_location="cpu")

        dtype = torch_dtype()

        # sometimes weights are hidden under "state_dict", and sometimes not
        if "state_dict" in checkpoint:
            checkpoint = checkpoint["state_dict"]

        ckpt_config = OmegaConf.load(self._app_config.legacy_conf_path / config_file)
        assert isinstance(ckpt_config, DictConfig)

        vae_model = convert_ldm_vae_to_diffusers(
            checkpoint=checkpoint,
            vae_config=ckpt_config,
            image_size=512,
        )
        vae_model.to(dtype)  # set precision appropriately
        vae_model.save_pretrained(output_path, safe_serialization=True, torch_dtype=dtype)
        return output_path

View File

@ -48,6 +48,9 @@ def _calc_onnx_model_by_data(model: IAIOnnxRuntimeModel) -> int:
def calc_model_size_by_fs(model_path: Path, subfolder: Optional[str] = None, variant: Optional[str] = None) -> int:
"""Estimate the size of a model on disk in bytes."""
if model_path.is_file():
return model_path.stat().st_size
if subfolder is not None:
model_path = model_path / subfolder

View File

@ -1,31 +0,0 @@
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
"""Class for VAE model loading in InvokeAI."""
from pathlib import Path
from typing import Dict, Optional
import torch
from invokeai.backend.model_manager import BaseModelType, ModelFormat, ModelRepoVariant, ModelType, SubModelType
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
from invokeai.backend.model_manager.load.load_default import ModelLoader
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers)
class VaeDiffusersModel(ModelLoader):
"""Class to load VAE models."""
def _load_model(
self,
model_path: Path,
model_variant: Optional[ModelRepoVariant] = None,
submodel_type: Optional[SubModelType] = None,
) -> Dict[str, torch.Tensor]:
# A submodel request is rejected: VAEs have no submodels.
if submodel_type is not None:
raise Exception("There are no submodels in VAEs")
vae_class = self._get_hf_load_class(model_path)
# When no variant is given, an empty string (not None) is passed as `variant`.
variant = model_variant.value if model_variant else ""
result: Dict[str, torch.Tensor] = vae_class.from_pretrained(
model_path, torch_dtype=self._torch_dtype, variant=variant
) # type: ignore
return result

View File

@ -12,6 +12,14 @@ from .devices import ( # noqa: F401
torch_dtype,
)
from .logging import InvokeAILogger
from .util import Chdir, ask_user, download_with_resume, instantiate_from_config, url_attachment_name # noqa: F401
from .util import ( # TO DO: Clean this up; remove the unused symbols
GIG,
Chdir,
ask_user, # noqa
directory_size,
download_with_resume,
instantiate_from_config, # noqa
url_attachment_name, # noqa
)
__all__ = ["Chdir", "InvokeAILogger", "choose_precision", "choose_torch_device"]
__all__ = ["GIG", "directory_size","Chdir", "download_with_resume", "InvokeAILogger", "choose_precision", "choose_torch_device"]

View File

@ -1,7 +1,7 @@
from __future__ import annotations
from contextlib import nullcontext
from typing import Union
from typing import Union, Optional
import torch
from torch import autocast
@ -43,7 +43,8 @@ def choose_precision(device: torch.device) -> str:
return "float32"
def torch_dtype(device: torch.device) -> torch.dtype:
def torch_dtype(device: Optional[torch.device] = None) -> torch.dtype:
device = device or choose_torch_device()
precision = choose_precision(device)
if precision == "float16":
return torch.float16

View File

@ -24,6 +24,20 @@ import invokeai.backend.util.logging as logger
from .devices import torch_dtype
# actual size of a gig
GIG = 1073741824
def directory_size(directory: Path) -> int:
    """
    Return the aggregate size of everything under a directory (bytes).

    The tree is walked recursively. The total includes the size of every file and
    also the `st_size` that each subdirectory entry itself reports.

    :param directory: Root of the tree to measure.
    :return: Total size in bytes; 0 for an empty directory.
    """
    total = 0  # named `total` so the builtin `sum` is not shadowed
    for root, dirs, files in os.walk(directory):
        for name in files:
            total += Path(root, name).stat().st_size
        for name in dirs:
            total += Path(root, name).stat().st_size
    return total
def log_txt_as_img(wh, xc, size=10):
# wh a tuple of (width, height)