model loading and conversion implemented for vaes

This commit is contained in:
Lincoln Stein
2024-02-03 22:55:09 -05:00
committed by psychedelicious
parent 5c2884569e
commit 60aa3d4893
29 changed files with 2382 additions and 237 deletions

View File

@ -8,6 +8,8 @@ from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMe
from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk
from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.backend.model_manager.load import AnyModelLoader, ModelConvertCache
from invokeai.backend.model_manager.load.model_cache import ModelCache
from invokeai.backend.model_manager.metadata import ModelMetadataStore
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
from invokeai.backend.util.logging import InvokeAILogger
@ -98,15 +100,26 @@ class ApiDependencies:
)
model_manager = ModelManagerService(config, logger)
model_record_service = ModelRecordServiceSQL(db=db)
model_loader = AnyModelLoader(
app_config=config,
logger=logger,
ram_cache=ModelCache(
max_cache_size=config.ram_cache_size, max_vram_cache_size=config.vram_cache_size, logger=logger
),
convert_cache=ModelConvertCache(
cache_path=config.models_convert_cache_path, max_size=config.convert_cache_size
),
)
model_record_service = ModelRecordServiceSQL(db=db, loader=model_loader)
download_queue_service = DownloadQueueService(event_bus=events)
metadata_store = ModelMetadataStore(db=db)
model_install_service = ModelInstallService(
app_config=config,
record_store=model_record_service,
download_queue=download_queue_service,
metadata_store=metadata_store,
metadata_store=ModelMetadataStore(db=db),
event_bus=events,
)
model_manager = ModelManagerService(config, logger) # TO DO: legacy model manager v1. Remove
names = SimpleNameService()
performance_statistics = InvocationStatsService()
processor = DefaultInvocationProcessor()

View File

@ -237,6 +237,7 @@ class InvokeAIAppConfig(InvokeAISettings):
autoimport_dir : Path = Field(default=Path('autoimport'), description='Path to a directory of models files to be imported on startup.', json_schema_extra=Categories.Paths)
conf_path : Path = Field(default=Path('configs/models.yaml'), description='Path to models definition file', json_schema_extra=Categories.Paths)
models_dir : Path = Field(default=Path('models'), description='Path to the models directory', json_schema_extra=Categories.Paths)
convert_cache_dir : Path = Field(default=Path('models/.cache'), description='Path to the converted models cache directory', json_schema_extra=Categories.Paths)
legacy_conf_dir : Path = Field(default=Path('configs/stable-diffusion'), description='Path to directory of legacy checkpoint config files', json_schema_extra=Categories.Paths)
db_dir : Path = Field(default=Path('databases'), description='Path to InvokeAI databases directory', json_schema_extra=Categories.Paths)
outdir : Path = Field(default=Path('outputs'), description='Default folder for output images', json_schema_extra=Categories.Paths)
@ -262,6 +263,8 @@ class InvokeAIAppConfig(InvokeAISettings):
# CACHE
ram : float = Field(default=7.5, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number, GB)", json_schema_extra=Categories.ModelCache, )
vram : float = Field(default=0.25, ge=0, description="Amount of VRAM reserved for model storage (floating point number, GB)", json_schema_extra=Categories.ModelCache, )
convert_cache : float = Field(default=10.0, ge=0, description="Maximum size of on-disk converted models cache (GB)", json_schema_extra=Categories.ModelCache)
lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", json_schema_extra=Categories.ModelCache, )
log_memory_usage : bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.", json_schema_extra=Categories.ModelCache)
@ -404,6 +407,11 @@ class InvokeAIAppConfig(InvokeAISettings):
"""Path to the models directory."""
return self._resolve(self.models_dir)
@property
def models_convert_cache_path(self) -> Path:
"""Path to the converted cache models directory."""
return self._resolve(self.convert_cache_dir)
@property
def custom_nodes_path(self) -> Path:
"""Path to the custom nodes directory."""
@ -433,15 +441,20 @@ class InvokeAIAppConfig(InvokeAISettings):
return True
@property
def ram_cache_size(self) -> Union[Literal["auto"], float]:
"""Return the ram cache size using the legacy or modern setting."""
def ram_cache_size(self) -> float:
"""Return the ram cache size using the legacy or modern setting (GB)."""
return self.max_cache_size or self.ram
@property
def vram_cache_size(self) -> Union[Literal["auto"], float]:
"""Return the vram cache size using the legacy or modern setting."""
def vram_cache_size(self) -> float:
"""Return the vram cache size using the legacy or modern setting (GB)."""
return self.max_vram_cache_size or self.vram
@property
def convert_cache_size(self) -> float:
"""Return the convert cache size on disk (GB)."""
return self.convert_cache
@property
def use_cpu(self) -> bool:
"""Return true if the device is set to CPU or the always_use_cpu flag is set."""

View File

@ -145,7 +145,7 @@ class ModelInstallService(ModelInstallServiceBase):
) -> str: # noqa D102
model_path = Path(model_path)
config = config or {}
if config.get("source") is None:
if not config.get("source"):
config["source"] = model_path.resolve().as_posix()
return self._register(model_path, config)
@ -156,7 +156,7 @@ class ModelInstallService(ModelInstallServiceBase):
) -> str: # noqa D102
model_path = Path(model_path)
config = config or {}
if config.get("source") is None:
if not config.get("source"):
config["source"] = model_path.resolve().as_posix()
info: AnyModelConfig = self._probe_model(Path(model_path), config)
@ -300,6 +300,7 @@ class ModelInstallService(ModelInstallServiceBase):
job.total_bytes = self._stat_size(job.local_path)
job.bytes = job.total_bytes
self._signal_job_running(job)
job.config_in["source"] = str(job.source)
if job.inplace:
key = self.register_path(job.local_path, job.config_in)
else:

View File

@ -11,7 +11,7 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union
from pydantic import BaseModel, Field
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType
from invokeai.backend.model_manager import LoadedModel, AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore
@ -102,6 +102,19 @@ class ModelRecordServiceBase(ABC):
"""
pass
@abstractmethod
def load_model(self, key: str, submodel_type: Optional[SubModelType]) -> LoadedModel:
"""
Load the indicated model into memory and return a LoadedModel object.
:param key: Key of model config to be fetched.
:param submodel_type: For main (pipeline models), the submodel to fetch
Exceptions: UnknownModelException -- model with this key not known
NotImplementedException -- a model loader was not provided at initialization time
"""
pass
@property
@abstractmethod
def metadata_store(self) -> ModelMetadataStore:

View File

@ -42,6 +42,7 @@ Typical usage:
import json
import sqlite3
import time
from math import ceil
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, Union
@ -53,8 +54,10 @@ from invokeai.backend.model_manager.config import (
ModelConfigFactory,
ModelFormat,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore, UnknownMetadataException
from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel
from ..shared.sqlite.sqlite_database import SqliteDatabase
from .model_records_base import (
@ -69,16 +72,17 @@ from .model_records_base import (
class ModelRecordServiceSQL(ModelRecordServiceBase):
"""Implementation of the ModelConfigStore ABC using a SQL database."""
def __init__(self, db: SqliteDatabase):
def __init__(self, db: SqliteDatabase, loader: Optional[AnyModelLoader]=None):
"""
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
:param conn: sqlite3 connection object
:param lock: threading Lock object
:param db: Sqlite connection object
:param loader: Initialized model loader object (optional)
"""
super().__init__()
self._db = db
self._cursor = self._db.conn.cursor()
self._cursor = db.conn.cursor()
self._loader = loader
@property
def db(self) -> SqliteDatabase:
@ -199,7 +203,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT config FROM model_config
SELECT config, strftime('%s',updated_at) FROM model_config
WHERE id=?;
""",
(key,),
@ -207,9 +211,24 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
rows = self._cursor.fetchone()
if not rows:
raise UnknownModelException("model not found")
model = ModelConfigFactory.make_config(json.loads(rows[0]))
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
return model
def load_model(self, key: str, submodel_type: Optional[SubModelType]) -> LoadedModel:
"""
Load the indicated model into memory and return a LoadedModel object.
:param key: Key of model config to be fetched.
:param submodel_type: For main (pipeline models), the submodel to fetch.
Exceptions: UnknownModelException -- model with this key not known
NotImplementedException -- a model loader was not provided at initialization time
"""
if not self._loader:
raise NotImplementedError(f"Class {self.__class__} was not initialized with a model loader")
model_config = self.get_model(key)
return self._loader.load_model(model_config, submodel_type)
def exists(self, key: str) -> bool:
"""
Return True if a model with the indicated key exists in the databse.
@ -265,12 +284,12 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
with self._db.lock:
self._cursor.execute(
f"""--sql
select config FROM model_config
select config, strftime('%s',updated_at) FROM model_config
{where};
""",
tuple(bindings),
)
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()]
return results
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
@ -279,12 +298,12 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT config FROM model_config
SELECT config, strftime('%s',updated_at) FROM model_config
WHERE path=?;
""",
(str(path),),
)
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()]
return results
def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
@ -293,12 +312,12 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT config FROM model_config
SELECT config, strftime('%s',updated_at) FROM model_config
WHERE original_hash=?;
""",
(hash,),
)
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()]
return results
@property

View File

@ -8,6 +8,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_3 import build_migration_3
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_4 import build_migration_4
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_5 import build_migration_5
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_6 import build_migration_6
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
@ -33,6 +34,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator.register_migration(build_migration_3(app_config=config, logger=logger))
migrator.register_migration(build_migration_4())
migrator.register_migration(build_migration_5())
migrator.register_migration(build_migration_6())
migrator.run_migrations()
return db

View File

@ -0,0 +1,44 @@
import sqlite3
from logging import Logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration6Callback:
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._recreate_model_triggers(cursor)
def _recreate_model_triggers(self, cursor: sqlite3.Cursor) -> None:
"""
Adds the timestamp trigger to the model_config table.
This trigger was inadvertently dropped in earlier migration scripts.
"""
cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS model_config_updated_at
AFTER UPDATE
ON model_config FOR EACH ROW
BEGIN
UPDATE model_config SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE id = old.id;
END;
"""
)
def build_migration_6() -> Migration:
"""
Build the migration from database version 5 to 6.
This migration does the following:
- Adds the model_config_updated_at trigger if it does not exist
"""
migration_6 = Migration(
from_version=5,
to_version=6,
callback=Migration6Callback(),
)
return migration_6