mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fix issues identified during PR review by RyanjDick and brandonrising
- ModelMetadataStoreService is now injected into ModelRecordStoreService (these two services are really joined at the hip, and should someday be merged) - ModelRecordStoreService is now injected into ModelManagerService - Reduced timeout value for the various installer and download wait*() methods - Introduced a Mock modelmanager for testing - Removed bare print() statement with _logger in the install helper backend. - Removed unused code from model loader init file - Made `locker` a private variable in the `LoadedModel` object. - Fixed up model merge frontend (will be deprecated anyway!)
This commit is contained in:
parent
9758082dc5
commit
09e7d35b55
@ -25,6 +25,8 @@ from ..services.invoker import Invoker
|
||||
from ..services.latents_storage.latents_storage_disk import DiskLatentsStorage
|
||||
from ..services.latents_storage.latents_storage_forward_cache import ForwardCacheLatentsStorage
|
||||
from ..services.model_manager.model_manager_default import ModelManagerService
|
||||
from ..services.model_metadata import ModelMetadataStoreSQL
|
||||
from ..services.model_records import ModelRecordServiceSQL
|
||||
from ..services.names.names_default import SimpleNameService
|
||||
from ..services.session_processor.session_processor_default import DefaultSessionProcessor
|
||||
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
|
||||
@ -83,8 +85,12 @@ class ApiDependencies:
|
||||
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
|
||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
|
||||
download_queue_service = DownloadQueueService(event_bus=events)
|
||||
model_metadata_service = ModelMetadataStoreSQL(db=db)
|
||||
model_manager = ModelManagerService.build_model_manager(
|
||||
app_config=configuration, db=db, download_queue=download_queue_service, events=events
|
||||
app_config=configuration,
|
||||
model_record_service=ModelRecordServiceSQL(db=db, metadata_store=model_metadata_service),
|
||||
download_queue=download_queue_service,
|
||||
events=events,
|
||||
)
|
||||
names = SimpleNameService()
|
||||
performance_statistics = InvocationStatsService()
|
||||
|
@ -194,7 +194,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
"""Block until the indicated job has reached terminal state, or when timeout limit reached."""
|
||||
start = time.time()
|
||||
while not job.in_terminal_state:
|
||||
if self._job_completed_event.wait(timeout=5): # in case we miss an event
|
||||
if self._job_completed_event.wait(timeout=0.25): # in case we miss an event
|
||||
self._job_completed_event.clear()
|
||||
if timeout > 0 and time.time() - start > timeout:
|
||||
raise TimeoutError("Timeout exceeded")
|
||||
|
@ -46,8 +46,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
# This is to handle case of the model manager not being initialized, which happens
|
||||
# during some tests.
|
||||
services = self._invoker.services
|
||||
if services.model_manager is None or services.model_manager.load is None:
|
||||
yield None
|
||||
if not self._stats.get(graph_execution_state_id):
|
||||
# First time we're seeing this graph_execution_state_id.
|
||||
self._stats[graph_execution_state_id] = GraphExecutionStats()
|
||||
|
@ -18,7 +18,9 @@ from invokeai.app.services.events import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase
|
||||
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
|
||||
from ..model_metadata import ModelMetadataStoreBase
|
||||
|
||||
|
||||
class InstallStatus(str, Enum):
|
||||
@ -243,7 +245,7 @@ class ModelInstallServiceBase(ABC):
|
||||
app_config: InvokeAIAppConfig,
|
||||
record_store: ModelRecordServiceBase,
|
||||
download_queue: DownloadQueueServiceBase,
|
||||
metadata_store: ModelMetadataStore,
|
||||
metadata_store: ModelMetadataStoreBase,
|
||||
event_bus: Optional["EventServiceBase"] = None,
|
||||
):
|
||||
"""
|
||||
|
@ -20,7 +20,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase, TqdmProgress
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase, ModelRecordServiceSQL
|
||||
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
@ -33,7 +33,6 @@ from invokeai.backend.model_manager.metadata import (
|
||||
AnyModelRepoMetadata,
|
||||
CivitaiMetadataFetch,
|
||||
HuggingFaceMetadataFetch,
|
||||
ModelMetadataStore,
|
||||
ModelMetadataWithFiles,
|
||||
RemoteModelFile,
|
||||
)
|
||||
@ -65,7 +64,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
app_config: InvokeAIAppConfig,
|
||||
record_store: ModelRecordServiceBase,
|
||||
download_queue: DownloadQueueServiceBase,
|
||||
metadata_store: Optional[ModelMetadataStore] = None,
|
||||
event_bus: Optional[EventServiceBase] = None,
|
||||
session: Optional[Session] = None,
|
||||
):
|
||||
@ -93,14 +91,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._running = False
|
||||
self._session = session
|
||||
self._next_job_id = 0
|
||||
# There may not necessarily be a metadata store initialized
|
||||
# so we create one and initialize it with the same sql database
|
||||
# used by the record store service.
|
||||
if metadata_store:
|
||||
self._metadata_store = metadata_store
|
||||
else:
|
||||
assert isinstance(record_store, ModelRecordServiceSQL)
|
||||
self._metadata_store = ModelMetadataStore(record_store.db)
|
||||
self._metadata_store = record_store.metadata_store # for convenience
|
||||
|
||||
@property
|
||||
def app_config(self) -> InvokeAIAppConfig: # noqa D102
|
||||
@ -259,7 +250,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
"""Block until all installation jobs are done."""
|
||||
start = time.time()
|
||||
while len(self._download_cache) > 0:
|
||||
if self._downloads_changed_event.wait(timeout=5): # in case we miss an event
|
||||
if self._downloads_changed_event.wait(timeout=0.25): # in case we miss an event
|
||||
self._downloads_changed_event.clear()
|
||||
if timeout > 0 and time.time() - start > timeout:
|
||||
raise TimeoutError("Timeout exceeded")
|
||||
|
@ -5,7 +5,6 @@ from typing_extensions import Self
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache
|
||||
from invokeai.backend.model_manager.metadata import ModelMetadataStore
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
from ..config import InvokeAIAppConfig
|
||||
@ -13,8 +12,7 @@ from ..download import DownloadQueueServiceBase
|
||||
from ..events.events_base import EventServiceBase
|
||||
from ..model_install import ModelInstallService, ModelInstallServiceBase
|
||||
from ..model_load import ModelLoadService, ModelLoadServiceBase
|
||||
from ..model_records import ModelRecordServiceBase, ModelRecordServiceSQL
|
||||
from ..shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from ..model_records import ModelRecordServiceBase
|
||||
from .model_manager_base import ModelManagerServiceBase
|
||||
|
||||
|
||||
@ -64,7 +62,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
def build_model_manager(
|
||||
cls,
|
||||
app_config: InvokeAIAppConfig,
|
||||
db: SqliteDatabase,
|
||||
model_record_service: ModelRecordServiceBase,
|
||||
download_queue: DownloadQueueServiceBase,
|
||||
events: EventServiceBase,
|
||||
) -> Self:
|
||||
@ -82,19 +80,16 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
convert_cache = ModelConvertCache(
|
||||
cache_path=app_config.models_convert_cache_path, max_size=app_config.convert_cache_size
|
||||
)
|
||||
record_store = ModelRecordServiceSQL(db=db)
|
||||
loader = ModelLoadService(
|
||||
app_config=app_config,
|
||||
record_store=record_store,
|
||||
record_store=model_record_service,
|
||||
ram_cache=ram_cache,
|
||||
convert_cache=convert_cache,
|
||||
)
|
||||
record_store._loader = loader # yeah, there is a circular reference here
|
||||
installer = ModelInstallService(
|
||||
app_config=app_config,
|
||||
record_store=record_store,
|
||||
record_store=model_record_service,
|
||||
download_queue=download_queue,
|
||||
metadata_store=ModelMetadataStore(db=db),
|
||||
event_bus=events,
|
||||
)
|
||||
return cls(store=record_store, install=installer, load=loader)
|
||||
return cls(store=model_record_service, install=installer, load=loader)
|
||||
|
9
invokeai/app/services/model_metadata/__init__.py
Normal file
9
invokeai/app/services/model_metadata/__init__.py
Normal file
@ -0,0 +1,9 @@
|
||||
"""Init file for ModelMetadataStoreService module."""
|
||||
|
||||
from .metadata_store_base import ModelMetadataStoreBase
|
||||
from .metadata_store_sql import ModelMetadataStoreSQL
|
||||
|
||||
__all__ = [
|
||||
"ModelMetadataStoreBase",
|
||||
"ModelMetadataStoreSQL",
|
||||
]
|
65
invokeai/app/services/model_metadata/metadata_store_base.py
Normal file
65
invokeai/app/services/model_metadata/metadata_store_base.py
Normal file
@ -0,0 +1,65 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""
|
||||
Storage for Model Metadata
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Set, Tuple
|
||||
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
|
||||
|
||||
class ModelMetadataStoreBase(ABC):
|
||||
"""Store, search and fetch model metadata retrieved from remote repositories."""
|
||||
|
||||
@abstractmethod
|
||||
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
|
||||
"""
|
||||
Add a block of repo metadata to a model record.
|
||||
|
||||
The model record config must already exist in the database with the
|
||||
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
|
||||
|
||||
:param model_key: Existing model key in the `model_config` table
|
||||
:param metadata: ModelRepoMetadata object to store
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
|
||||
"""Retrieve the ModelRepoMetadata corresponding to model key."""
|
||||
|
||||
@abstractmethod
|
||||
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata
|
||||
"""Dump out all the metadata."""
|
||||
|
||||
@abstractmethod
|
||||
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
|
||||
"""
|
||||
Update metadata corresponding to the model with the indicated key.
|
||||
|
||||
:param model_key: Existing model key in the `model_config` table
|
||||
:param metadata: ModelRepoMetadata object to update
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def list_tags(self) -> Set[str]:
|
||||
"""Return all tags in the tags table."""
|
||||
|
||||
@abstractmethod
|
||||
def search_by_tag(self, tags: Set[str]) -> Set[str]:
|
||||
"""Return the keys of models containing all of the listed tags."""
|
||||
|
||||
@abstractmethod
|
||||
def search_by_author(self, author: str) -> Set[str]:
|
||||
"""Return the keys of models authored by the indicated author."""
|
||||
|
||||
@abstractmethod
|
||||
def search_by_name(self, name: str) -> Set[str]:
|
||||
"""
|
||||
Return the keys of models with the indicated name.
|
||||
|
||||
Note that this is the name of the model given to it by
|
||||
the remote source. The user may have changed the local
|
||||
name. The local name will be located in the model config
|
||||
record object.
|
||||
"""
|
222
invokeai/app/services/model_metadata/metadata_store_sql.py
Normal file
222
invokeai/app/services/model_metadata/metadata_store_sql.py
Normal file
@ -0,0 +1,222 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""
|
||||
SQL Storage for Model Metadata
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
from typing import List, Optional, Set, Tuple
|
||||
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException
|
||||
from invokeai.backend.model_manager.metadata.fetch import ModelMetadataFetchBase
|
||||
|
||||
from .metadata_store_base import ModelMetadataStoreBase
|
||||
|
||||
|
||||
class ModelMetadataStoreSQL(ModelMetadataStoreBase):
|
||||
"""Store, search and fetch model metadata retrieved from remote repositories."""
|
||||
|
||||
def __init__(self, db: SqliteDatabase):
|
||||
"""
|
||||
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
|
||||
|
||||
:param conn: sqlite3 connection object
|
||||
:param lock: threading Lock object
|
||||
"""
|
||||
super().__init__()
|
||||
self._db = db
|
||||
self._cursor = self._db.conn.cursor()
|
||||
|
||||
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
|
||||
"""
|
||||
Add a block of repo metadata to a model record.
|
||||
|
||||
The model record config must already exist in the database with the
|
||||
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
|
||||
|
||||
:param model_key: Existing model key in the `model_config` table
|
||||
:param metadata: ModelRepoMetadata object to store
|
||||
"""
|
||||
json_serialized = metadata.model_dump_json()
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO model_metadata(
|
||||
id,
|
||||
metadata
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(
|
||||
model_key,
|
||||
json_serialized,
|
||||
),
|
||||
)
|
||||
self._update_tags(model_key, metadata.tags)
|
||||
self._db.conn.commit()
|
||||
except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table
|
||||
self._db.conn.rollback()
|
||||
raise UnknownMetadataException from excp
|
||||
except sqlite3.Error as excp:
|
||||
self._db.conn.rollback()
|
||||
raise excp
|
||||
|
||||
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
|
||||
"""Retrieve the ModelRepoMetadata corresponding to model key."""
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT metadata FROM model_metadata
|
||||
WHERE id=?;
|
||||
""",
|
||||
(model_key,),
|
||||
)
|
||||
rows = self._cursor.fetchone()
|
||||
if not rows:
|
||||
raise UnknownMetadataException("model metadata not found")
|
||||
return ModelMetadataFetchBase.from_json(rows[0])
|
||||
|
||||
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata
|
||||
"""Dump out all the metadata."""
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT id,metadata FROM model_metadata;
|
||||
""",
|
||||
(),
|
||||
)
|
||||
rows = self._cursor.fetchall()
|
||||
return [(x[0], ModelMetadataFetchBase.from_json(x[1])) for x in rows]
|
||||
|
||||
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
|
||||
"""
|
||||
Update metadata corresponding to the model with the indicated key.
|
||||
|
||||
:param model_key: Existing model key in the `model_config` table
|
||||
:param metadata: ModelRepoMetadata object to update
|
||||
"""
|
||||
json_serialized = metadata.model_dump_json() # turn it into a json string.
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
UPDATE model_metadata
|
||||
SET
|
||||
metadata=?
|
||||
WHERE id=?;
|
||||
""",
|
||||
(json_serialized, model_key),
|
||||
)
|
||||
if self._cursor.rowcount == 0:
|
||||
raise UnknownMetadataException("model metadata not found")
|
||||
self._update_tags(model_key, metadata.tags)
|
||||
self._db.conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
return self.get_metadata(model_key)
|
||||
|
||||
def list_tags(self) -> Set[str]:
|
||||
"""Return all tags in the tags table."""
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
select tag_text from tags;
|
||||
"""
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def search_by_tag(self, tags: Set[str]) -> Set[str]:
|
||||
"""Return the keys of models containing all of the listed tags."""
|
||||
with self._db.lock:
|
||||
try:
|
||||
matches: Optional[Set[str]] = None
|
||||
for tag in tags:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT a.model_id FROM model_tags AS a,
|
||||
tags AS b
|
||||
WHERE a.tag_id=b.tag_id
|
||||
AND b.tag_text=?;
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
model_keys = {x[0] for x in self._cursor.fetchall()}
|
||||
if matches is None:
|
||||
matches = model_keys
|
||||
matches = matches.intersection(model_keys)
|
||||
except sqlite3.Error as e:
|
||||
raise e
|
||||
return matches if matches else set()
|
||||
|
||||
def search_by_author(self, author: str) -> Set[str]:
|
||||
"""Return the keys of models authored by the indicated author."""
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT id FROM model_metadata
|
||||
WHERE author=?;
|
||||
""",
|
||||
(author,),
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def search_by_name(self, name: str) -> Set[str]:
|
||||
"""
|
||||
Return the keys of models with the indicated name.
|
||||
|
||||
Note that this is the name of the model given to it by
|
||||
the remote source. The user may have changed the local
|
||||
name. The local name will be located in the model config
|
||||
record object.
|
||||
"""
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT id FROM model_metadata
|
||||
WHERE name=?;
|
||||
""",
|
||||
(name,),
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def _update_tags(self, model_key: str, tags: Set[str]) -> None:
|
||||
"""Update tags for the model referenced by model_key."""
|
||||
# remove previous tags from this model
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM model_tags
|
||||
WHERE model_id=?;
|
||||
""",
|
||||
(model_key,),
|
||||
)
|
||||
|
||||
for tag in tags:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO tags (
|
||||
tag_text
|
||||
)
|
||||
VALUES (?);
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT tag_id
|
||||
FROM tags
|
||||
WHERE tag_text = ?
|
||||
LIMIT 1;
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
tag_id = self._cursor.fetchone()[0]
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO model_tags (
|
||||
model_id,
|
||||
tag_id
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(model_key, tag_id),
|
||||
)
|
@ -17,7 +17,9 @@ from invokeai.backend.model_manager import (
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
|
||||
from ..model_metadata import ModelMetadataStoreBase
|
||||
|
||||
|
||||
class DuplicateModelException(Exception):
|
||||
@ -109,7 +111,7 @@ class ModelRecordServiceBase(ABC):
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def metadata_store(self) -> ModelMetadataStore:
|
||||
def metadata_store(self) -> ModelMetadataStoreBase:
|
||||
"""Return a ModelMetadataStore initialized on the same database."""
|
||||
pass
|
||||
|
||||
|
@ -54,8 +54,9 @@ from invokeai.backend.model_manager.config import (
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore, UnknownMetadataException
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException
|
||||
|
||||
from ..model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL
|
||||
from ..shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from .model_records_base import (
|
||||
DuplicateModelException,
|
||||
@ -69,7 +70,7 @@ from .model_records_base import (
|
||||
class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
"""Implementation of the ModelConfigStore ABC using a SQL database."""
|
||||
|
||||
def __init__(self, db: SqliteDatabase):
|
||||
def __init__(self, db: SqliteDatabase, metadata_store: ModelMetadataStoreBase):
|
||||
"""
|
||||
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
|
||||
|
||||
@ -78,6 +79,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
super().__init__()
|
||||
self._db = db
|
||||
self._cursor = db.conn.cursor()
|
||||
self._metadata_store = metadata_store
|
||||
|
||||
@property
|
||||
def db(self) -> SqliteDatabase:
|
||||
@ -157,7 +159,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
def update_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
|
||||
def update_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
|
||||
"""
|
||||
Update the model, returning the updated version.
|
||||
|
||||
@ -307,9 +309,9 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
return results
|
||||
|
||||
@property
|
||||
def metadata_store(self) -> ModelMetadataStore:
|
||||
def metadata_store(self) -> ModelMetadataStoreBase:
|
||||
"""Return a ModelMetadataStore initialized on the same database."""
|
||||
return ModelMetadataStore(self._db)
|
||||
return self._metadata_store
|
||||
|
||||
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
|
||||
"""
|
||||
@ -330,18 +332,18 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
|
||||
:param tags: Set of tags to search for. All tags must be present.
|
||||
"""
|
||||
store = ModelMetadataStore(self._db)
|
||||
store = ModelMetadataStoreSQL(self._db)
|
||||
keys = store.search_by_tag(tags)
|
||||
return [self.get_model(x) for x in keys]
|
||||
|
||||
def list_tags(self) -> Set[str]:
|
||||
"""Return a unique set of all the model tags in the metadata database."""
|
||||
store = ModelMetadataStore(self._db)
|
||||
store = ModelMetadataStoreSQL(self._db)
|
||||
return store.list_tags()
|
||||
|
||||
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]:
|
||||
"""List metadata for all models that have it."""
|
||||
store = ModelMetadataStore(self._db)
|
||||
store = ModelMetadataStoreSQL(self._db)
|
||||
return store.list_all_metadata()
|
||||
|
||||
def list_models(
|
||||
|
@ -25,6 +25,7 @@ from invokeai.app.services.model_install import (
|
||||
ModelSource,
|
||||
URLModelSource,
|
||||
)
|
||||
from invokeai.app.services.model_metadata import ModelMetadataStoreSQL
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
|
||||
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
||||
from invokeai.backend.model_manager import (
|
||||
@ -45,7 +46,7 @@ def initialize_record_store(app_config: InvokeAIAppConfig) -> ModelRecordService
|
||||
logger = InvokeAILogger.get_logger(config=app_config)
|
||||
image_files = DiskImageFileStorage(f"{app_config.output_path}/images")
|
||||
db = init_db(config=app_config, logger=logger, image_files=image_files)
|
||||
obj: ModelRecordServiceBase = ModelRecordServiceSQL(db)
|
||||
obj: ModelRecordServiceBase = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
|
||||
return obj
|
||||
|
||||
|
||||
@ -54,12 +55,10 @@ def initialize_installer(
|
||||
) -> ModelInstallServiceBase:
|
||||
"""Return an initialized ModelInstallService object."""
|
||||
record_store = initialize_record_store(app_config)
|
||||
metadata_store = record_store.metadata_store
|
||||
download_queue = DownloadQueueService()
|
||||
installer = ModelInstallService(
|
||||
app_config=app_config,
|
||||
record_store=record_store,
|
||||
metadata_store=metadata_store,
|
||||
download_queue=download_queue,
|
||||
event_bus=event_bus,
|
||||
)
|
||||
@ -287,14 +286,14 @@ class InstallHelper(object):
|
||||
model_name=model_name,
|
||||
)
|
||||
if len(matches) > 1:
|
||||
print(
|
||||
f"{model_to_remove} is ambiguous. Please use model_base/model_type/model_name (e.g. sd-1/main/my_model) to disambiguate."
|
||||
self._logger.error(
|
||||
"{model_to_remove} is ambiguous. Please use model_base/model_type/model_name (e.g. sd-1/main/my_model) to disambiguate"
|
||||
)
|
||||
elif not matches:
|
||||
print(f"{model_to_remove}: unknown model")
|
||||
self._logger.error(f"{model_to_remove}: unknown model")
|
||||
else:
|
||||
for m in matches:
|
||||
print(f"Deleting {m.type}:{m.name}")
|
||||
self._logger.info(f"Deleting {m.type}:{m.name}")
|
||||
installer.delete(m.key)
|
||||
|
||||
installer.wait_for_installs()
|
||||
|
@ -4,10 +4,6 @@ Init file for the model loader.
|
||||
"""
|
||||
from importlib import import_module
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
from .convert_cache.convert_cache_default import ModelConvertCache
|
||||
from .load_base import AnyModelLoader, LoadedModel
|
||||
@ -19,16 +15,3 @@ for module in loaders:
|
||||
import_module(f"{__package__}.model_loaders.{module}")
|
||||
|
||||
__all__ = ["AnyModelLoader", "LoadedModel", "ModelCache", "ModelConvertCache"]
|
||||
|
||||
|
||||
def get_standalone_loader(app_config: Optional[InvokeAIAppConfig]) -> AnyModelLoader:
|
||||
app_config = app_config or InvokeAIAppConfig.get_config()
|
||||
logger = InvokeAILogger.get_logger(config=app_config)
|
||||
return AnyModelLoader(
|
||||
app_config=app_config,
|
||||
logger=logger,
|
||||
ram_cache=ModelCache(
|
||||
logger=logger, max_cache_size=app_config.ram_cache_size, max_vram_cache_size=app_config.vram_cache_size
|
||||
),
|
||||
convert_cache=ModelConvertCache(app_config.models_convert_cache_path),
|
||||
)
|
||||
|
@ -39,21 +39,21 @@ class LoadedModel:
|
||||
"""Context manager object that mediates transfer from RAM<->VRAM."""
|
||||
|
||||
config: AnyModelConfig
|
||||
locker: ModelLockerBase
|
||||
_locker: ModelLockerBase
|
||||
|
||||
def __enter__(self) -> AnyModel:
|
||||
"""Context entry."""
|
||||
self.locker.lock()
|
||||
self._locker.lock()
|
||||
return self.model
|
||||
|
||||
def __exit__(self, *args: Any, **kwargs: Any) -> None:
|
||||
"""Context exit."""
|
||||
self.locker.unlock()
|
||||
self._locker.unlock()
|
||||
|
||||
@property
|
||||
def model(self) -> AnyModel:
|
||||
"""Return the model without locking it."""
|
||||
return self.locker.model
|
||||
return self._locker.model
|
||||
|
||||
|
||||
class ModelLoaderBase(ABC):
|
||||
|
@ -75,7 +75,7 @@ class ModelLoader(ModelLoaderBase):
|
||||
|
||||
model_path = self._convert_if_needed(model_config, model_path, submodel_type)
|
||||
locker = self._load_if_needed(model_config, model_path, submodel_type)
|
||||
return LoadedModel(config=model_config, locker=locker)
|
||||
return LoadedModel(config=model_config, _locker=locker)
|
||||
|
||||
def _get_model_path(
|
||||
self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None
|
||||
|
@ -39,10 +39,7 @@ class ModelMerger(object):
|
||||
|
||||
def __init__(self, installer: ModelInstallServiceBase):
|
||||
"""
|
||||
Initialize a ModelMerger object.
|
||||
|
||||
:param store: Underlying storage manager for the running process.
|
||||
:param config: InvokeAIAppConfig object (if not provided, default will be selected).
|
||||
Initialize a ModelMerger object with the model installer.
|
||||
"""
|
||||
self._installer = installer
|
||||
|
||||
|
@ -18,7 +18,7 @@ assert isinstance(data, CivitaiMetadata)
|
||||
if data.allow_commercial_use:
|
||||
print("Commercial use of this model is allowed")
|
||||
"""
|
||||
from .fetch import CivitaiMetadataFetch, HuggingFaceMetadataFetch
|
||||
from .fetch import CivitaiMetadataFetch, HuggingFaceMetadataFetch, ModelMetadataFetchBase
|
||||
from .metadata_base import (
|
||||
AnyModelRepoMetadata,
|
||||
AnyModelRepoMetadataValidator,
|
||||
@ -31,7 +31,6 @@ from .metadata_base import (
|
||||
RemoteModelFile,
|
||||
UnknownMetadataException,
|
||||
)
|
||||
from .metadata_store import ModelMetadataStore
|
||||
|
||||
__all__ = [
|
||||
"AnyModelRepoMetadata",
|
||||
@ -42,7 +41,7 @@ __all__ = [
|
||||
"HuggingFaceMetadata",
|
||||
"HuggingFaceMetadataFetch",
|
||||
"LicenseRestrictions",
|
||||
"ModelMetadataStore",
|
||||
"ModelMetadataFetchBase",
|
||||
"BaseMetadata",
|
||||
"ModelMetadataWithFiles",
|
||||
"RemoteModelFile",
|
||||
|
@ -6,20 +6,40 @@ Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
|
||||
"""
|
||||
import argparse
|
||||
import curses
|
||||
import re
|
||||
import sys
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import npyscreen
|
||||
from npyscreen import widget
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_management import BaseModelType, ModelManager, ModelMerger, ModelType
|
||||
from invokeai.app.services.download import DownloadQueueService
|
||||
from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage
|
||||
from invokeai.app.services.model_install import ModelInstallService
|
||||
from invokeai.app.services.model_metadata import ModelMetadataStoreSQL
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
|
||||
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
||||
from invokeai.backend.model_manager import (
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
)
|
||||
from invokeai.backend.model_manager.merge import ModelMerger
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.frontend.install.widgets import FloatTitleSlider, SingleSelectColumns, TextBox
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
logger = InvokeAILogger.get_logger()
|
||||
|
||||
BASE_TYPES = [
|
||||
(BaseModelType.StableDiffusion1, "Models Built on SD-1.x"),
|
||||
(BaseModelType.StableDiffusion2, "Models Built on SD-2.x"),
|
||||
(BaseModelType.StableDiffusionXL, "Models Built on SDXL"),
|
||||
]
|
||||
|
||||
|
||||
def _parse_args() -> Namespace:
|
||||
@ -48,7 +68,7 @@ def _parse_args() -> Namespace:
|
||||
parser.add_argument(
|
||||
"--base_model",
|
||||
type=str,
|
||||
choices=[x.value for x in BaseModelType],
|
||||
choices=[x[0].value for x in BASE_TYPES],
|
||||
help="The base model shared by the models to be merged",
|
||||
)
|
||||
parser.add_argument(
|
||||
@ -98,17 +118,17 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
super().__init__(parentApp, name)
|
||||
|
||||
@property
|
||||
def model_manager(self):
|
||||
return self.parentApp.model_manager
|
||||
def record_store(self):
|
||||
return self.parentApp.record_store
|
||||
|
||||
def afterEditing(self):
|
||||
self.parentApp.setNextForm(None)
|
||||
|
||||
def create(self):
|
||||
window_height, window_width = curses.initscr().getmaxyx()
|
||||
|
||||
self.model_names = self.get_model_names()
|
||||
self.current_base = 0
|
||||
self.models = self.get_models(BASE_TYPES[self.current_base][0])
|
||||
self.model_names = [x[1] for x in self.models]
|
||||
max_width = max([len(x) for x in self.model_names])
|
||||
max_width += 6
|
||||
horizontal_layout = max_width * 3 < window_width
|
||||
@ -128,11 +148,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
self.nextrely += 1
|
||||
self.base_select = self.add_widget_intelligent(
|
||||
SingleSelectColumns,
|
||||
values=[
|
||||
"Models Built on SD-1.x",
|
||||
"Models Built on SD-2.x",
|
||||
"Models Built on SDXL",
|
||||
],
|
||||
values=[x[1] for x in BASE_TYPES],
|
||||
value=[self.current_base],
|
||||
columns=4,
|
||||
max_height=2,
|
||||
@ -263,21 +279,20 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
sys.exit(0)
|
||||
|
||||
def marshall_arguments(self) -> dict:
|
||||
model_names = self.model_names
|
||||
model_keys = [x[0] for x in self.models]
|
||||
models = [
|
||||
model_names[self.model1.value[0]],
|
||||
model_names[self.model2.value[0]],
|
||||
model_keys[self.model1.value[0]],
|
||||
model_keys[self.model2.value[0]],
|
||||
]
|
||||
if self.model3.value[0] > 0:
|
||||
models.append(model_names[self.model3.value[0] - 1])
|
||||
models.append(model_keys[self.model3.value[0] - 1])
|
||||
interp = "add_difference"
|
||||
else:
|
||||
interp = self.interpolations[self.merge_method.value[0]]
|
||||
|
||||
bases = ["sd-1", "sd-2", "sdxl"]
|
||||
args = {
|
||||
"model_names": models,
|
||||
"base_model": BaseModelType(bases[self.base_select.value[0]]),
|
||||
"model_keys": models,
|
||||
"base_model": tuple(BaseModelType)[self.base_select.value[0]],
|
||||
"alpha": self.alpha.value,
|
||||
"interp": interp,
|
||||
"force": self.force.value,
|
||||
@ -311,18 +326,18 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
else:
|
||||
return True
|
||||
|
||||
def get_model_names(self, base_model: BaseModelType = BaseModelType.StableDiffusion1) -> List[str]:
|
||||
model_names = [
|
||||
info["model_name"]
|
||||
for info in self.model_manager.list_models(model_type=ModelType.Main, base_model=base_model)
|
||||
if info["model_format"] == "diffusers"
|
||||
def get_models(self, base_model: Optional[BaseModelType] = None) -> List[Tuple[str, str]]: # key to name
|
||||
models = [
|
||||
(x.key, x.name)
|
||||
for x in self.record_store.search_by_attr(model_type=ModelType.Main, base_model=base_model)
|
||||
if x.format == ModelFormat("diffusers") and x.variant == ModelVariantType("normal")
|
||||
]
|
||||
return sorted(model_names)
|
||||
return sorted(models, key=lambda x: x[1])
|
||||
|
||||
def _populate_models(self, value=None):
|
||||
bases = ["sd-1", "sd-2", "sdxl"]
|
||||
base_model = BaseModelType(bases[value[0]])
|
||||
self.model_names = self.get_model_names(base_model)
|
||||
def _populate_models(self, value: List[int]):
|
||||
base_model = BASE_TYPES[value[0]][0]
|
||||
self.models = self.get_models(base_model)
|
||||
self.model_names = [x[1] for x in self.models]
|
||||
|
||||
models_plus_none = self.model_names.copy()
|
||||
models_plus_none.insert(0, "None")
|
||||
@ -334,24 +349,24 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
|
||||
|
||||
class Mergeapp(npyscreen.NPSAppManaged):
|
||||
def __init__(self, model_manager: ModelManager):
|
||||
def __init__(self, record_store: ModelRecordServiceBase):
|
||||
super().__init__()
|
||||
self.model_manager = model_manager
|
||||
self.record_store = record_store
|
||||
|
||||
def onStart(self):
|
||||
npyscreen.setTheme(npyscreen.Themes.ElegantTheme)
|
||||
self.main = self.addForm("MAIN", mergeModelsForm, name="Merge Models Settings")
|
||||
|
||||
|
||||
def run_gui(args: Namespace):
|
||||
model_manager = ModelManager(config.model_conf_path)
|
||||
mergeapp = Mergeapp(model_manager)
|
||||
def run_gui(args: Namespace) -> None:
|
||||
record_store: ModelRecordServiceBase = get_config_store()
|
||||
mergeapp = Mergeapp(record_store)
|
||||
mergeapp.run()
|
||||
|
||||
args = mergeapp.merge_arguments
|
||||
merger = ModelMerger(model_manager)
|
||||
merger = get_model_merger(record_store)
|
||||
merger.merge_diffusion_models_and_save(**args)
|
||||
logger.info(f'Models merged into new model: "{args["merged_model_name"]}".')
|
||||
merged_model_name = args["merged_model_name"]
|
||||
logger.info(f'Models merged into new model: "{merged_model_name}".')
|
||||
|
||||
|
||||
def run_cli(args: Namespace):
|
||||
@ -364,20 +379,54 @@ def run_cli(args: Namespace):
|
||||
args.merged_model_name = "+".join(args.model_names)
|
||||
logger.info(f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"')
|
||||
|
||||
model_manager = ModelManager(config.model_conf_path)
|
||||
record_store: ModelRecordServiceBase = get_config_store()
|
||||
assert (
|
||||
not model_manager.model_exists(args.merged_model_name, args.base_model, ModelType.Main) or args.clobber
|
||||
len(record_store.search_by_attr(args.merged_model_name, args.base_model, ModelType.Main)) == 0 or args.clobber
|
||||
), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
|
||||
|
||||
merger = ModelMerger(model_manager)
|
||||
merger.merge_diffusion_models_and_save(**vars(args))
|
||||
merger = get_model_merger(record_store)
|
||||
model_keys = []
|
||||
for name in args.model_names:
|
||||
if len(name) == 32 and re.match(r"^[0-9a-f]$", name):
|
||||
model_keys.append(name)
|
||||
else:
|
||||
models = record_store.search_by_attr(
|
||||
model_name=name, model_type=ModelType.Main, base_model=BaseModelType(args.base_model)
|
||||
)
|
||||
assert len(models) > 0, f"{name}: Unknown model"
|
||||
assert len(models) < 2, f"{name}: More than one model by this name. Please specify the model key instead."
|
||||
model_keys.append(models[0].key)
|
||||
|
||||
merger.merge_diffusion_models_and_save(
|
||||
alpha=args.alpha,
|
||||
model_keys=model_keys,
|
||||
merged_model_name=args.merged_model_name,
|
||||
interp=args.interp,
|
||||
force=args.force,
|
||||
)
|
||||
logger.info(f'Models merged into new model: "{args.merged_model_name}".')
|
||||
|
||||
|
||||
def get_config_store() -> ModelRecordServiceSQL:
|
||||
output_path = config.output_path
|
||||
assert output_path is not None
|
||||
image_files = DiskImageFileStorage(output_path / "images")
|
||||
db = init_db(config=config, logger=InvokeAILogger.get_logger(), image_files=image_files)
|
||||
return ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
|
||||
|
||||
|
||||
def get_model_merger(record_store: ModelRecordServiceBase) -> ModelMerger:
|
||||
installer = ModelInstallService(app_config=config, record_store=record_store, download_queue=DownloadQueueService())
|
||||
installer.start()
|
||||
return ModelMerger(installer)
|
||||
|
||||
|
||||
def main():
|
||||
args = _parse_args()
|
||||
if args.root_dir:
|
||||
config.parse_args(["--root", str(args.root_dir)])
|
||||
else:
|
||||
config.parse_args([])
|
||||
|
||||
try:
|
||||
if args.front_end:
|
||||
|
@ -1,4 +1,5 @@
|
||||
import logging
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
@ -65,7 +66,7 @@ def mock_services() -> InvocationServices:
|
||||
invocation_cache=MemoryInvocationCache(max_cache_size=0),
|
||||
latents=None, # type: ignore
|
||||
logger=logging, # type: ignore
|
||||
model_manager=None, # type: ignore
|
||||
model_manager=Mock(), # type: ignore
|
||||
download_queue=None, # type: ignore
|
||||
names=None, # type: ignore
|
||||
performance_statistics=InvocationStatsService(),
|
||||
|
@ -8,6 +8,7 @@ from typing import Any
|
||||
import pytest
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.model_metadata import ModelMetadataStoreSQL
|
||||
from invokeai.app.services.model_records import (
|
||||
DuplicateModelException,
|
||||
ModelRecordOrderBy,
|
||||
@ -36,7 +37,7 @@ def store(
|
||||
config = InvokeAIAppConfig(root=datadir)
|
||||
logger = InvokeAILogger.get_logger(config=config)
|
||||
db = create_mock_sqlite_database(config, logger)
|
||||
return ModelRecordServiceSQL(db)
|
||||
return ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
|
||||
|
||||
|
||||
def example_config() -> TextualInversionConfig:
|
||||
|
@ -14,6 +14,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.download import DownloadQueueService
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase
|
||||
from invokeai.app.services.model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL
|
||||
from invokeai.app.services.model_records import ModelRecordServiceSQL
|
||||
from invokeai.backend.model_manager.config import (
|
||||
BaseModelType,
|
||||
@ -21,7 +22,6 @@ from invokeai.backend.model_manager.config import (
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load import AnyModelLoader, ModelCache, ModelConvertCache
|
||||
from invokeai.backend.model_manager.metadata import ModelMetadataStore
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from tests.backend.model_manager_2.model_metadata.metadata_examples import (
|
||||
RepoCivitaiModelMetadata1,
|
||||
@ -104,7 +104,7 @@ def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordS
|
||||
def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL:
|
||||
logger = InvokeAILogger.get_logger(config=mm2_app_config)
|
||||
db = create_mock_sqlite_database(mm2_app_config, logger)
|
||||
store = ModelRecordServiceSQL(db)
|
||||
store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
|
||||
# add five simple config records to the database
|
||||
raw1 = {
|
||||
"path": "/tmp/foo1",
|
||||
@ -163,15 +163,14 @@ def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStore:
|
||||
db = mm2_record_store._db # to ensure we are sharing the same database
|
||||
return ModelMetadataStore(db)
|
||||
def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStoreBase:
|
||||
return mm2_record_store.metadata_store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
|
||||
"""This fixtures defines a series of mock URLs for testing download and installation."""
|
||||
sess = TestSession()
|
||||
sess: Session = TestSession()
|
||||
sess.mount(
|
||||
"https://test.com/missing_model.safetensors",
|
||||
TestAdapter(
|
||||
@ -258,8 +257,7 @@ def mm2_installer(mm2_app_config: InvokeAIAppConfig, mm2_session: Session) -> Mo
|
||||
logger = InvokeAILogger.get_logger()
|
||||
db = create_mock_sqlite_database(mm2_app_config, logger)
|
||||
events = DummyEventService()
|
||||
store = ModelRecordServiceSQL(db)
|
||||
metadata_store = ModelMetadataStore(db)
|
||||
store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
|
||||
|
||||
download_queue = DownloadQueueService(requests_session=mm2_session)
|
||||
download_queue.start()
|
||||
@ -268,7 +266,6 @@ def mm2_installer(mm2_app_config: InvokeAIAppConfig, mm2_session: Session) -> Mo
|
||||
app_config=mm2_app_config,
|
||||
record_store=store,
|
||||
download_queue=download_queue,
|
||||
metadata_store=metadata_store,
|
||||
event_bus=events,
|
||||
session=mm2_session,
|
||||
)
|
||||
|
@ -8,6 +8,7 @@ import pytest
|
||||
from pydantic.networks import HttpUrl
|
||||
from requests.sessions import Session
|
||||
|
||||
from invokeai.app.services.model_metadata import ModelMetadataStoreBase
|
||||
from invokeai.backend.model_manager.config import ModelRepoVariant
|
||||
from invokeai.backend.model_manager.metadata import (
|
||||
CivitaiMetadata,
|
||||
@ -15,14 +16,13 @@ from invokeai.backend.model_manager.metadata import (
|
||||
CommercialUsage,
|
||||
HuggingFaceMetadata,
|
||||
HuggingFaceMetadataFetch,
|
||||
ModelMetadataStore,
|
||||
UnknownMetadataException,
|
||||
)
|
||||
from invokeai.backend.model_manager.util import select_hf_files
|
||||
from tests.backend.model_manager_2.model_manager_2_fixtures import * # noqa F403
|
||||
|
||||
|
||||
def test_metadata_store_put_get(mm2_metadata_store: ModelMetadataStore) -> None:
|
||||
def test_metadata_store_put_get(mm2_metadata_store: ModelMetadataStoreBase) -> None:
|
||||
tags = {"text-to-image", "diffusers"}
|
||||
input_metadata = HuggingFaceMetadata(
|
||||
name="sdxl-vae",
|
||||
@ -40,7 +40,7 @@ def test_metadata_store_put_get(mm2_metadata_store: ModelMetadataStore) -> None:
|
||||
assert mm2_metadata_store.list_tags() == tags
|
||||
|
||||
|
||||
def test_metadata_store_update(mm2_metadata_store: ModelMetadataStore) -> None:
|
||||
def test_metadata_store_update(mm2_metadata_store: ModelMetadataStoreBase) -> None:
|
||||
input_metadata = HuggingFaceMetadata(
|
||||
name="sdxl-vae",
|
||||
author="stabilityai",
|
||||
@ -57,7 +57,7 @@ def test_metadata_store_update(mm2_metadata_store: ModelMetadataStore) -> None:
|
||||
assert input_metadata == output_metadata
|
||||
|
||||
|
||||
def test_metadata_search(mm2_metadata_store: ModelMetadataStore) -> None:
|
||||
def test_metadata_search(mm2_metadata_store: ModelMetadataStoreBase) -> None:
|
||||
metadata1 = HuggingFaceMetadata(
|
||||
name="sdxl-vae",
|
||||
author="stabilityai",
|
||||
|
Loading…
Reference in New Issue
Block a user