Fix issues identified during PR review by RyanjDick and brandonrising

- ModelMetadataStoreService is now injected into ModelRecordStoreService
  (these two services are really joined at the hip, and should someday be merged)
- ModelRecordStoreService is now injected into ModelManagerService
- Reduced timeout value for the various installer and download wait*() methods
- Introduced a Mock modelmanager for testing
- Removed bare print() statement with _logger in the install helper backend.
- Removed unused code from model loader init file
- Made `locker` a private variable in the `LoadedModel` object.
- Fixed up model merge frontend (will be deprecated anyway!)
This commit is contained in:
Lincoln Stein 2024-02-15 22:41:29 -05:00 committed by psychedelicious
parent f1597bd6da
commit 996eb96b4e
22 changed files with 449 additions and 131 deletions

View File

@ -28,6 +28,8 @@ from ..services.invocation_services import InvocationServices
from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
from ..services.invoker import Invoker
from ..services.model_manager.model_manager_default import ModelManagerService
from ..services.model_metadata import ModelMetadataStoreSQL
from ..services.model_records import ModelRecordServiceSQL
from ..services.names.names_default import SimpleNameService
from ..services.session_processor.session_processor_default import DefaultSessionProcessor
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
@ -94,8 +96,12 @@ class ApiDependencies:
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
)
download_queue_service = DownloadQueueService(event_bus=events)
model_metadata_service = ModelMetadataStoreSQL(db=db)
model_manager = ModelManagerService.build_model_manager(
app_config=configuration, db=db, download_queue=download_queue_service, events=events
app_config=configuration,
model_record_service=ModelRecordServiceSQL(db=db, metadata_store=model_metadata_service),
download_queue=download_queue_service,
events=events,
)
names = SimpleNameService()
performance_statistics = InvocationStatsService()

View File

@ -194,7 +194,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
"""Block until the indicated job has reached terminal state, or when timeout limit reached."""
start = time.time()
while not job.in_terminal_state:
if self._job_completed_event.wait(timeout=5): # in case we miss an event
if self._job_completed_event.wait(timeout=0.25): # in case we miss an event
self._job_completed_event.clear()
if timeout > 0 and time.time() - start > timeout:
raise TimeoutError("Timeout exceeded")

View File

@ -46,8 +46,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
# This is to handle case of the model manager not being initialized, which happens
# during some tests.
services = self._invoker.services
if services.model_manager is None or services.model_manager.load is None:
yield None
if not self._stats.get(graph_execution_state_id):
# First time we're seeing this graph_execution_state_id.
self._stats[graph_execution_state_id] = GraphExecutionStats()

View File

@ -18,7 +18,9 @@ from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from ..model_metadata import ModelMetadataStoreBase
class InstallStatus(str, Enum):
@ -243,7 +245,7 @@ class ModelInstallServiceBase(ABC):
app_config: InvokeAIAppConfig,
record_store: ModelRecordServiceBase,
download_queue: DownloadQueueServiceBase,
metadata_store: ModelMetadataStore,
metadata_store: ModelMetadataStoreBase,
event_bus: Optional["EventServiceBase"] = None,
):
"""

View File

@ -20,7 +20,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase, TqdmProgress
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase, ModelRecordServiceSQL
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
@ -33,7 +33,6 @@ from invokeai.backend.model_manager.metadata import (
AnyModelRepoMetadata,
CivitaiMetadataFetch,
HuggingFaceMetadataFetch,
ModelMetadataStore,
ModelMetadataWithFiles,
RemoteModelFile,
)
@ -65,7 +64,6 @@ class ModelInstallService(ModelInstallServiceBase):
app_config: InvokeAIAppConfig,
record_store: ModelRecordServiceBase,
download_queue: DownloadQueueServiceBase,
metadata_store: Optional[ModelMetadataStore] = None,
event_bus: Optional[EventServiceBase] = None,
session: Optional[Session] = None,
):
@ -93,14 +91,7 @@ class ModelInstallService(ModelInstallServiceBase):
self._running = False
self._session = session
self._next_job_id = 0
# There may not necessarily be a metadata store initialized
# so we create one and initialize it with the same sql database
# used by the record store service.
if metadata_store:
self._metadata_store = metadata_store
else:
assert isinstance(record_store, ModelRecordServiceSQL)
self._metadata_store = ModelMetadataStore(record_store.db)
self._metadata_store = record_store.metadata_store # for convenience
@property
def app_config(self) -> InvokeAIAppConfig: # noqa D102
@ -259,7 +250,7 @@ class ModelInstallService(ModelInstallServiceBase):
"""Block until all installation jobs are done."""
start = time.time()
while len(self._download_cache) > 0:
if self._downloads_changed_event.wait(timeout=5): # in case we miss an event
if self._downloads_changed_event.wait(timeout=0.25): # in case we miss an event
self._downloads_changed_event.clear()
if timeout > 0 and time.time() - start > timeout:
raise TimeoutError("Timeout exceeded")

View File

@ -5,7 +5,6 @@ from typing_extensions import Self
from invokeai.app.services.invoker import Invoker
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache
from invokeai.backend.model_manager.metadata import ModelMetadataStore
from invokeai.backend.util.logging import InvokeAILogger
from ..config import InvokeAIAppConfig
@ -13,8 +12,7 @@ from ..download import DownloadQueueServiceBase
from ..events.events_base import EventServiceBase
from ..model_install import ModelInstallService, ModelInstallServiceBase
from ..model_load import ModelLoadService, ModelLoadServiceBase
from ..model_records import ModelRecordServiceBase, ModelRecordServiceSQL
from ..shared.sqlite.sqlite_database import SqliteDatabase
from ..model_records import ModelRecordServiceBase
from .model_manager_base import ModelManagerServiceBase
@ -64,7 +62,7 @@ class ModelManagerService(ModelManagerServiceBase):
def build_model_manager(
cls,
app_config: InvokeAIAppConfig,
db: SqliteDatabase,
model_record_service: ModelRecordServiceBase,
download_queue: DownloadQueueServiceBase,
events: EventServiceBase,
) -> Self:
@ -82,19 +80,16 @@ class ModelManagerService(ModelManagerServiceBase):
convert_cache = ModelConvertCache(
cache_path=app_config.models_convert_cache_path, max_size=app_config.convert_cache_size
)
record_store = ModelRecordServiceSQL(db=db)
loader = ModelLoadService(
app_config=app_config,
record_store=record_store,
record_store=model_record_service,
ram_cache=ram_cache,
convert_cache=convert_cache,
)
record_store._loader = loader # yeah, there is a circular reference here
installer = ModelInstallService(
app_config=app_config,
record_store=record_store,
record_store=model_record_service,
download_queue=download_queue,
metadata_store=ModelMetadataStore(db=db),
event_bus=events,
)
return cls(store=record_store, install=installer, load=loader)
return cls(store=model_record_service, install=installer, load=loader)

View File

@ -0,0 +1,9 @@
"""Init file for ModelMetadataStoreService module."""
from .metadata_store_base import ModelMetadataStoreBase
from .metadata_store_sql import ModelMetadataStoreSQL
__all__ = [
"ModelMetadataStoreBase",
"ModelMetadataStoreSQL",
]

View File

@ -0,0 +1,65 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Storage for Model Metadata
"""
from abc import ABC, abstractmethod
from typing import List, Set, Tuple
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
class ModelMetadataStoreBase(ABC):
"""Store, search and fetch model metadata retrieved from remote repositories."""
@abstractmethod
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
"""
Add a block of repo metadata to a model record.
The model record config must already exist in the database with the
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
:param model_key: Existing model key in the `model_config` table
:param metadata: ModelRepoMetadata object to store
"""
@abstractmethod
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
"""Retrieve the ModelRepoMetadata corresponding to model key."""
@abstractmethod
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata
"""Dump out all the metadata."""
@abstractmethod
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
"""
Update metadata corresponding to the model with the indicated key.
:param model_key: Existing model key in the `model_config` table
:param metadata: ModelRepoMetadata object to update
"""
@abstractmethod
def list_tags(self) -> Set[str]:
"""Return all tags in the tags table."""
@abstractmethod
def search_by_tag(self, tags: Set[str]) -> Set[str]:
"""Return the keys of models containing all of the listed tags."""
@abstractmethod
def search_by_author(self, author: str) -> Set[str]:
"""Return the keys of models authored by the indicated author."""
@abstractmethod
def search_by_name(self, name: str) -> Set[str]:
"""
Return the keys of models with the indicated name.
Note that this is the name of the model given to it by
the remote source. The user may have changed the local
name. The local name will be located in the model config
record object.
"""

View File

@ -0,0 +1,222 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
SQL Storage for Model Metadata
"""
import sqlite3
from typing import List, Optional, Set, Tuple
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException
from invokeai.backend.model_manager.metadata.fetch import ModelMetadataFetchBase
from .metadata_store_base import ModelMetadataStoreBase
class ModelMetadataStoreSQL(ModelMetadataStoreBase):
"""Store, search and fetch model metadata retrieved from remote repositories."""
def __init__(self, db: SqliteDatabase):
"""
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
:param conn: sqlite3 connection object
:param lock: threading Lock object
"""
super().__init__()
self._db = db
self._cursor = self._db.conn.cursor()
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
"""
Add a block of repo metadata to a model record.
The model record config must already exist in the database with the
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
:param model_key: Existing model key in the `model_config` table
:param metadata: ModelRepoMetadata object to store
"""
json_serialized = metadata.model_dump_json()
with self._db.lock:
try:
self._cursor.execute(
"""--sql
INSERT INTO model_metadata(
id,
metadata
)
VALUES (?,?);
""",
(
model_key,
json_serialized,
),
)
self._update_tags(model_key, metadata.tags)
self._db.conn.commit()
except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table
self._db.conn.rollback()
raise UnknownMetadataException from excp
except sqlite3.Error as excp:
self._db.conn.rollback()
raise excp
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
"""Retrieve the ModelRepoMetadata corresponding to model key."""
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT metadata FROM model_metadata
WHERE id=?;
""",
(model_key,),
)
rows = self._cursor.fetchone()
if not rows:
raise UnknownMetadataException("model metadata not found")
return ModelMetadataFetchBase.from_json(rows[0])
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata
"""Dump out all the metadata."""
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT id,metadata FROM model_metadata;
""",
(),
)
rows = self._cursor.fetchall()
return [(x[0], ModelMetadataFetchBase.from_json(x[1])) for x in rows]
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
"""
Update metadata corresponding to the model with the indicated key.
:param model_key: Existing model key in the `model_config` table
:param metadata: ModelRepoMetadata object to update
"""
json_serialized = metadata.model_dump_json() # turn it into a json string.
with self._db.lock:
try:
self._cursor.execute(
"""--sql
UPDATE model_metadata
SET
metadata=?
WHERE id=?;
""",
(json_serialized, model_key),
)
if self._cursor.rowcount == 0:
raise UnknownMetadataException("model metadata not found")
self._update_tags(model_key, metadata.tags)
self._db.conn.commit()
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
return self.get_metadata(model_key)
def list_tags(self) -> Set[str]:
"""Return all tags in the tags table."""
self._cursor.execute(
"""--sql
select tag_text from tags;
"""
)
return {x[0] for x in self._cursor.fetchall()}
def search_by_tag(self, tags: Set[str]) -> Set[str]:
"""Return the keys of models containing all of the listed tags."""
with self._db.lock:
try:
matches: Optional[Set[str]] = None
for tag in tags:
self._cursor.execute(
"""--sql
SELECT a.model_id FROM model_tags AS a,
tags AS b
WHERE a.tag_id=b.tag_id
AND b.tag_text=?;
""",
(tag,),
)
model_keys = {x[0] for x in self._cursor.fetchall()}
if matches is None:
matches = model_keys
matches = matches.intersection(model_keys)
except sqlite3.Error as e:
raise e
return matches if matches else set()
def search_by_author(self, author: str) -> Set[str]:
"""Return the keys of models authored by the indicated author."""
self._cursor.execute(
"""--sql
SELECT id FROM model_metadata
WHERE author=?;
""",
(author,),
)
return {x[0] for x in self._cursor.fetchall()}
def search_by_name(self, name: str) -> Set[str]:
"""
Return the keys of models with the indicated name.
Note that this is the name of the model given to it by
the remote source. The user may have changed the local
name. The local name will be located in the model config
record object.
"""
self._cursor.execute(
"""--sql
SELECT id FROM model_metadata
WHERE name=?;
""",
(name,),
)
return {x[0] for x in self._cursor.fetchall()}
def _update_tags(self, model_key: str, tags: Set[str]) -> None:
"""Update tags for the model referenced by model_key."""
# remove previous tags from this model
self._cursor.execute(
"""--sql
DELETE FROM model_tags
WHERE model_id=?;
""",
(model_key,),
)
for tag in tags:
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO tags (
tag_text
)
VALUES (?);
""",
(tag,),
)
self._cursor.execute(
"""--sql
SELECT tag_id
FROM tags
WHERE tag_text = ?
LIMIT 1;
""",
(tag,),
)
tag_id = self._cursor.fetchone()[0]
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO model_tags (
model_id,
tag_id
)
VALUES (?,?);
""",
(model_key, tag_id),
)

View File

@ -17,7 +17,9 @@ from invokeai.backend.model_manager import (
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from ..model_metadata import ModelMetadataStoreBase
class DuplicateModelException(Exception):
@ -109,7 +111,7 @@ class ModelRecordServiceBase(ABC):
@property
@abstractmethod
def metadata_store(self) -> ModelMetadataStore:
def metadata_store(self) -> ModelMetadataStoreBase:
"""Return a ModelMetadataStore initialized on the same database."""
pass

View File

@ -54,8 +54,9 @@ from invokeai.backend.model_manager.config import (
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore, UnknownMetadataException
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException
from ..model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL
from ..shared.sqlite.sqlite_database import SqliteDatabase
from .model_records_base import (
DuplicateModelException,
@ -69,7 +70,7 @@ from .model_records_base import (
class ModelRecordServiceSQL(ModelRecordServiceBase):
"""Implementation of the ModelConfigStore ABC using a SQL database."""
def __init__(self, db: SqliteDatabase):
def __init__(self, db: SqliteDatabase, metadata_store: ModelMetadataStoreBase):
"""
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
@ -78,6 +79,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
super().__init__()
self._db = db
self._cursor = db.conn.cursor()
self._metadata_store = metadata_store
@property
def db(self) -> SqliteDatabase:
@ -157,7 +159,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
self._db.conn.rollback()
raise e
def update_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
def update_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
"""
Update the model, returning the updated version.
@ -307,9 +309,9 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
return results
@property
def metadata_store(self) -> ModelMetadataStore:
def metadata_store(self) -> ModelMetadataStoreBase:
"""Return a ModelMetadataStore initialized on the same database."""
return ModelMetadataStore(self._db)
return self._metadata_store
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
"""
@ -330,18 +332,18 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
:param tags: Set of tags to search for. All tags must be present.
"""
store = ModelMetadataStore(self._db)
store = ModelMetadataStoreSQL(self._db)
keys = store.search_by_tag(tags)
return [self.get_model(x) for x in keys]
def list_tags(self) -> Set[str]:
"""Return a unique set of all the model tags in the metadata database."""
store = ModelMetadataStore(self._db)
store = ModelMetadataStoreSQL(self._db)
return store.list_tags()
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]:
"""List metadata for all models that have it."""
store = ModelMetadataStore(self._db)
store = ModelMetadataStoreSQL(self._db)
return store.list_all_metadata()
def list_models(

View File

@ -25,6 +25,7 @@ from invokeai.app.services.model_install import (
ModelSource,
URLModelSource,
)
from invokeai.app.services.model_metadata import ModelMetadataStoreSQL
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.backend.model_manager import (
@ -45,7 +46,7 @@ def initialize_record_store(app_config: InvokeAIAppConfig) -> ModelRecordService
logger = InvokeAILogger.get_logger(config=app_config)
image_files = DiskImageFileStorage(f"{app_config.output_path}/images")
db = init_db(config=app_config, logger=logger, image_files=image_files)
obj: ModelRecordServiceBase = ModelRecordServiceSQL(db)
obj: ModelRecordServiceBase = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
return obj
@ -54,12 +55,10 @@ def initialize_installer(
) -> ModelInstallServiceBase:
"""Return an initialized ModelInstallService object."""
record_store = initialize_record_store(app_config)
metadata_store = record_store.metadata_store
download_queue = DownloadQueueService()
installer = ModelInstallService(
app_config=app_config,
record_store=record_store,
metadata_store=metadata_store,
download_queue=download_queue,
event_bus=event_bus,
)
@ -287,14 +286,14 @@ class InstallHelper(object):
model_name=model_name,
)
if len(matches) > 1:
print(
f"{model_to_remove} is ambiguous. Please use model_base/model_type/model_name (e.g. sd-1/main/my_model) to disambiguate."
self._logger.error(
"{model_to_remove} is ambiguous. Please use model_base/model_type/model_name (e.g. sd-1/main/my_model) to disambiguate"
)
elif not matches:
print(f"{model_to_remove}: unknown model")
self._logger.error(f"{model_to_remove}: unknown model")
else:
for m in matches:
print(f"Deleting {m.type}:{m.name}")
self._logger.info(f"Deleting {m.type}:{m.name}")
installer.delete(m.key)
installer.wait_for_installs()

View File

@ -4,10 +4,6 @@ Init file for the model loader.
"""
from importlib import import_module
from pathlib import Path
from typing import Optional
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger
from .convert_cache.convert_cache_default import ModelConvertCache
from .load_base import AnyModelLoader, LoadedModel
@ -19,16 +15,3 @@ for module in loaders:
import_module(f"{__package__}.model_loaders.{module}")
__all__ = ["AnyModelLoader", "LoadedModel", "ModelCache", "ModelConvertCache"]
def get_standalone_loader(app_config: Optional[InvokeAIAppConfig]) -> AnyModelLoader:
app_config = app_config or InvokeAIAppConfig.get_config()
logger = InvokeAILogger.get_logger(config=app_config)
return AnyModelLoader(
app_config=app_config,
logger=logger,
ram_cache=ModelCache(
logger=logger, max_cache_size=app_config.ram_cache_size, max_vram_cache_size=app_config.vram_cache_size
),
convert_cache=ModelConvertCache(app_config.models_convert_cache_path),
)

View File

@ -39,21 +39,21 @@ class LoadedModel:
"""Context manager object that mediates transfer from RAM<->VRAM."""
config: AnyModelConfig
locker: ModelLockerBase
_locker: ModelLockerBase
def __enter__(self) -> AnyModel:
"""Context entry."""
self.locker.lock()
self._locker.lock()
return self.model
def __exit__(self, *args: Any, **kwargs: Any) -> None:
"""Context exit."""
self.locker.unlock()
self._locker.unlock()
@property
def model(self) -> AnyModel:
"""Return the model without locking it."""
return self.locker.model
return self._locker.model
class ModelLoaderBase(ABC):

View File

@ -75,7 +75,7 @@ class ModelLoader(ModelLoaderBase):
model_path = self._convert_if_needed(model_config, model_path, submodel_type)
locker = self._load_if_needed(model_config, model_path, submodel_type)
return LoadedModel(config=model_config, locker=locker)
return LoadedModel(config=model_config, _locker=locker)
def _get_model_path(
self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None

View File

@ -39,10 +39,7 @@ class ModelMerger(object):
def __init__(self, installer: ModelInstallServiceBase):
"""
Initialize a ModelMerger object.
:param store: Underlying storage manager for the running process.
:param config: InvokeAIAppConfig object (if not provided, default will be selected).
Initialize a ModelMerger object with the model installer.
"""
self._installer = installer

View File

@ -18,7 +18,7 @@ assert isinstance(data, CivitaiMetadata)
if data.allow_commercial_use:
print("Commercial use of this model is allowed")
"""
from .fetch import CivitaiMetadataFetch, HuggingFaceMetadataFetch
from .fetch import CivitaiMetadataFetch, HuggingFaceMetadataFetch, ModelMetadataFetchBase
from .metadata_base import (
AnyModelRepoMetadata,
AnyModelRepoMetadataValidator,
@ -31,7 +31,6 @@ from .metadata_base import (
RemoteModelFile,
UnknownMetadataException,
)
from .metadata_store import ModelMetadataStore
__all__ = [
"AnyModelRepoMetadata",
@ -42,7 +41,7 @@ __all__ = [
"HuggingFaceMetadata",
"HuggingFaceMetadataFetch",
"LicenseRestrictions",
"ModelMetadataStore",
"ModelMetadataFetchBase",
"BaseMetadata",
"ModelMetadataWithFiles",
"RemoteModelFile",

View File

@ -6,20 +6,40 @@ Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
"""
import argparse
import curses
import re
import sys
from argparse import Namespace
from pathlib import Path
from typing import List
from typing import List, Optional, Tuple
import npyscreen
from npyscreen import widget
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_management import BaseModelType, ModelManager, ModelMerger, ModelType
from invokeai.app.services.download import DownloadQueueService
from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage
from invokeai.app.services.model_install import ModelInstallService
from invokeai.app.services.model_metadata import ModelMetadataStoreSQL
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.backend.model_manager import (
BaseModelType,
ModelFormat,
ModelType,
ModelVariantType,
)
from invokeai.backend.model_manager.merge import ModelMerger
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.frontend.install.widgets import FloatTitleSlider, SingleSelectColumns, TextBox
config = InvokeAIAppConfig.get_config()
logger = InvokeAILogger.get_logger()
BASE_TYPES = [
(BaseModelType.StableDiffusion1, "Models Built on SD-1.x"),
(BaseModelType.StableDiffusion2, "Models Built on SD-2.x"),
(BaseModelType.StableDiffusionXL, "Models Built on SDXL"),
]
def _parse_args() -> Namespace:
@ -48,7 +68,7 @@ def _parse_args() -> Namespace:
parser.add_argument(
"--base_model",
type=str,
choices=[x.value for x in BaseModelType],
choices=[x[0].value for x in BASE_TYPES],
help="The base model shared by the models to be merged",
)
parser.add_argument(
@ -98,17 +118,17 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
super().__init__(parentApp, name)
@property
def model_manager(self):
return self.parentApp.model_manager
def record_store(self):
return self.parentApp.record_store
def afterEditing(self):
self.parentApp.setNextForm(None)
def create(self):
window_height, window_width = curses.initscr().getmaxyx()
self.model_names = self.get_model_names()
self.current_base = 0
self.models = self.get_models(BASE_TYPES[self.current_base][0])
self.model_names = [x[1] for x in self.models]
max_width = max([len(x) for x in self.model_names])
max_width += 6
horizontal_layout = max_width * 3 < window_width
@ -128,11 +148,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
self.nextrely += 1
self.base_select = self.add_widget_intelligent(
SingleSelectColumns,
values=[
"Models Built on SD-1.x",
"Models Built on SD-2.x",
"Models Built on SDXL",
],
values=[x[1] for x in BASE_TYPES],
value=[self.current_base],
columns=4,
max_height=2,
@ -263,21 +279,20 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
sys.exit(0)
def marshall_arguments(self) -> dict:
model_names = self.model_names
model_keys = [x[0] for x in self.models]
models = [
model_names[self.model1.value[0]],
model_names[self.model2.value[0]],
model_keys[self.model1.value[0]],
model_keys[self.model2.value[0]],
]
if self.model3.value[0] > 0:
models.append(model_names[self.model3.value[0] - 1])
models.append(model_keys[self.model3.value[0] - 1])
interp = "add_difference"
else:
interp = self.interpolations[self.merge_method.value[0]]
bases = ["sd-1", "sd-2", "sdxl"]
args = {
"model_names": models,
"base_model": BaseModelType(bases[self.base_select.value[0]]),
"model_keys": models,
"base_model": tuple(BaseModelType)[self.base_select.value[0]],
"alpha": self.alpha.value,
"interp": interp,
"force": self.force.value,
@ -311,18 +326,18 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
else:
return True
def get_model_names(self, base_model: BaseModelType = BaseModelType.StableDiffusion1) -> List[str]:
model_names = [
info["model_name"]
for info in self.model_manager.list_models(model_type=ModelType.Main, base_model=base_model)
if info["model_format"] == "diffusers"
def get_models(self, base_model: Optional[BaseModelType] = None) -> List[Tuple[str, str]]: # key to name
models = [
(x.key, x.name)
for x in self.record_store.search_by_attr(model_type=ModelType.Main, base_model=base_model)
if x.format == ModelFormat("diffusers") and x.variant == ModelVariantType("normal")
]
return sorted(model_names)
return sorted(models, key=lambda x: x[1])
def _populate_models(self, value=None):
bases = ["sd-1", "sd-2", "sdxl"]
base_model = BaseModelType(bases[value[0]])
self.model_names = self.get_model_names(base_model)
def _populate_models(self, value: List[int]):
base_model = BASE_TYPES[value[0]][0]
self.models = self.get_models(base_model)
self.model_names = [x[1] for x in self.models]
models_plus_none = self.model_names.copy()
models_plus_none.insert(0, "None")
@ -334,24 +349,24 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
class Mergeapp(npyscreen.NPSAppManaged):
def __init__(self, model_manager: ModelManager):
def __init__(self, record_store: ModelRecordServiceBase):
super().__init__()
self.model_manager = model_manager
self.record_store = record_store
def onStart(self):
npyscreen.setTheme(npyscreen.Themes.ElegantTheme)
self.main = self.addForm("MAIN", mergeModelsForm, name="Merge Models Settings")
def run_gui(args: Namespace):
model_manager = ModelManager(config.model_conf_path)
mergeapp = Mergeapp(model_manager)
def run_gui(args: Namespace) -> None:
record_store: ModelRecordServiceBase = get_config_store()
mergeapp = Mergeapp(record_store)
mergeapp.run()
args = mergeapp.merge_arguments
merger = ModelMerger(model_manager)
merger = get_model_merger(record_store)
merger.merge_diffusion_models_and_save(**args)
logger.info(f'Models merged into new model: "{args["merged_model_name"]}".')
merged_model_name = args["merged_model_name"]
logger.info(f'Models merged into new model: "{merged_model_name}".')
def run_cli(args: Namespace):
@ -364,20 +379,54 @@ def run_cli(args: Namespace):
args.merged_model_name = "+".join(args.model_names)
logger.info(f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"')
model_manager = ModelManager(config.model_conf_path)
record_store: ModelRecordServiceBase = get_config_store()
assert (
not model_manager.model_exists(args.merged_model_name, args.base_model, ModelType.Main) or args.clobber
len(record_store.search_by_attr(args.merged_model_name, args.base_model, ModelType.Main)) == 0 or args.clobber
), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
merger = ModelMerger(model_manager)
merger.merge_diffusion_models_and_save(**vars(args))
merger = get_model_merger(record_store)
model_keys = []
for name in args.model_names:
if len(name) == 32 and re.match(r"^[0-9a-f]$", name):
model_keys.append(name)
else:
models = record_store.search_by_attr(
model_name=name, model_type=ModelType.Main, base_model=BaseModelType(args.base_model)
)
assert len(models) > 0, f"{name}: Unknown model"
assert len(models) < 2, f"{name}: More than one model by this name. Please specify the model key instead."
model_keys.append(models[0].key)
merger.merge_diffusion_models_and_save(
alpha=args.alpha,
model_keys=model_keys,
merged_model_name=args.merged_model_name,
interp=args.interp,
force=args.force,
)
logger.info(f'Models merged into new model: "{args.merged_model_name}".')
def get_config_store() -> ModelRecordServiceSQL:
output_path = config.output_path
assert output_path is not None
image_files = DiskImageFileStorage(output_path / "images")
db = init_db(config=config, logger=InvokeAILogger.get_logger(), image_files=image_files)
return ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
def get_model_merger(record_store: ModelRecordServiceBase) -> ModelMerger:
installer = ModelInstallService(app_config=config, record_store=record_store, download_queue=DownloadQueueService())
installer.start()
return ModelMerger(installer)
def main():
args = _parse_args()
if args.root_dir:
config.parse_args(["--root", str(args.root_dir)])
else:
config.parse_args([])
try:
if args.front_end:

View File

@ -1,4 +1,5 @@
import logging
from unittest.mock import Mock
import pytest
@ -64,7 +65,7 @@ def mock_services() -> InvocationServices:
images=None, # type: ignore
invocation_cache=MemoryInvocationCache(max_cache_size=0),
logger=logging, # type: ignore
model_manager=None, # type: ignore
model_manager=Mock(), # type: ignore
download_queue=None, # type: ignore
names=None, # type: ignore
performance_statistics=InvocationStatsService(),

View File

@ -8,6 +8,7 @@ from typing import Any
import pytest
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_metadata import ModelMetadataStoreSQL
from invokeai.app.services.model_records import (
DuplicateModelException,
ModelRecordOrderBy,
@ -36,7 +37,7 @@ def store(
config = InvokeAIAppConfig(root=datadir)
logger = InvokeAILogger.get_logger(config=config)
db = create_mock_sqlite_database(config, logger)
return ModelRecordServiceSQL(db)
return ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
def example_config() -> TextualInversionConfig:

View File

@ -14,6 +14,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadQueueService
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase
from invokeai.app.services.model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL
from invokeai.app.services.model_records import ModelRecordServiceSQL
from invokeai.backend.model_manager.config import (
BaseModelType,
@ -21,7 +22,6 @@ from invokeai.backend.model_manager.config import (
ModelType,
)
from invokeai.backend.model_manager.load import AnyModelLoader, ModelCache, ModelConvertCache
from invokeai.backend.model_manager.metadata import ModelMetadataStore
from invokeai.backend.util.logging import InvokeAILogger
from tests.backend.model_manager_2.model_metadata.metadata_examples import (
RepoCivitaiModelMetadata1,
@ -104,7 +104,7 @@ def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordS
def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL:
logger = InvokeAILogger.get_logger(config=mm2_app_config)
db = create_mock_sqlite_database(mm2_app_config, logger)
store = ModelRecordServiceSQL(db)
store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
# add five simple config records to the database
raw1 = {
"path": "/tmp/foo1",
@ -163,15 +163,14 @@ def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL
@pytest.fixture
def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStore:
db = mm2_record_store._db # to ensure we are sharing the same database
return ModelMetadataStore(db)
def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStoreBase:
return mm2_record_store.metadata_store
@pytest.fixture
def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
"""This fixtures defines a series of mock URLs for testing download and installation."""
sess = TestSession()
sess: Session = TestSession()
sess.mount(
"https://test.com/missing_model.safetensors",
TestAdapter(
@ -258,8 +257,7 @@ def mm2_installer(mm2_app_config: InvokeAIAppConfig, mm2_session: Session) -> Mo
logger = InvokeAILogger.get_logger()
db = create_mock_sqlite_database(mm2_app_config, logger)
events = DummyEventService()
store = ModelRecordServiceSQL(db)
metadata_store = ModelMetadataStore(db)
store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
download_queue = DownloadQueueService(requests_session=mm2_session)
download_queue.start()
@ -268,7 +266,6 @@ def mm2_installer(mm2_app_config: InvokeAIAppConfig, mm2_session: Session) -> Mo
app_config=mm2_app_config,
record_store=store,
download_queue=download_queue,
metadata_store=metadata_store,
event_bus=events,
session=mm2_session,
)

View File

@ -8,6 +8,7 @@ import pytest
from pydantic.networks import HttpUrl
from requests.sessions import Session
from invokeai.app.services.model_metadata import ModelMetadataStoreBase
from invokeai.backend.model_manager.config import ModelRepoVariant
from invokeai.backend.model_manager.metadata import (
CivitaiMetadata,
@ -15,14 +16,13 @@ from invokeai.backend.model_manager.metadata import (
CommercialUsage,
HuggingFaceMetadata,
HuggingFaceMetadataFetch,
ModelMetadataStore,
UnknownMetadataException,
)
from invokeai.backend.model_manager.util import select_hf_files
from tests.backend.model_manager_2.model_manager_2_fixtures import * # noqa F403
def test_metadata_store_put_get(mm2_metadata_store: ModelMetadataStore) -> None:
def test_metadata_store_put_get(mm2_metadata_store: ModelMetadataStoreBase) -> None:
tags = {"text-to-image", "diffusers"}
input_metadata = HuggingFaceMetadata(
name="sdxl-vae",
@ -40,7 +40,7 @@ def test_metadata_store_put_get(mm2_metadata_store: ModelMetadataStore) -> None:
assert mm2_metadata_store.list_tags() == tags
def test_metadata_store_update(mm2_metadata_store: ModelMetadataStore) -> None:
def test_metadata_store_update(mm2_metadata_store: ModelMetadataStoreBase) -> None:
input_metadata = HuggingFaceMetadata(
name="sdxl-vae",
author="stabilityai",
@ -57,7 +57,7 @@ def test_metadata_store_update(mm2_metadata_store: ModelMetadataStore) -> None:
assert input_metadata == output_metadata
def test_metadata_search(mm2_metadata_store: ModelMetadataStore) -> None:
def test_metadata_search(mm2_metadata_store: ModelMetadataStoreBase) -> None:
metadata1 = HuggingFaceMetadata(
name="sdxl-vae",
author="stabilityai",