mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fix issues identified during PR review by RyanjDick and brandonrising
- ModelMetadataStoreService is now injected into ModelRecordStoreService (these two services are really joined at the hip, and should someday be merged) - ModelRecordStoreService is now injected into ModelManagerService - Reduced timeout value for the various installer and download wait*() methods - Introduced a Mock modelmanager for testing - Removed bare print() statement with _logger in the install helper backend. - Removed unused code from model loader init file - Made `locker` a private variable in the `LoadedModel` object. - Fixed up model merge frontend (will be deprecated anyway!)
This commit is contained in:
parent
9758082dc5
commit
09e7d35b55
@ -25,6 +25,8 @@ from ..services.invoker import Invoker
|
|||||||
from ..services.latents_storage.latents_storage_disk import DiskLatentsStorage
|
from ..services.latents_storage.latents_storage_disk import DiskLatentsStorage
|
||||||
from ..services.latents_storage.latents_storage_forward_cache import ForwardCacheLatentsStorage
|
from ..services.latents_storage.latents_storage_forward_cache import ForwardCacheLatentsStorage
|
||||||
from ..services.model_manager.model_manager_default import ModelManagerService
|
from ..services.model_manager.model_manager_default import ModelManagerService
|
||||||
|
from ..services.model_metadata import ModelMetadataStoreSQL
|
||||||
|
from ..services.model_records import ModelRecordServiceSQL
|
||||||
from ..services.names.names_default import SimpleNameService
|
from ..services.names.names_default import SimpleNameService
|
||||||
from ..services.session_processor.session_processor_default import DefaultSessionProcessor
|
from ..services.session_processor.session_processor_default import DefaultSessionProcessor
|
||||||
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
|
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
|
||||||
@ -83,8 +85,12 @@ class ApiDependencies:
|
|||||||
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
|
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
|
||||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
|
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
|
||||||
download_queue_service = DownloadQueueService(event_bus=events)
|
download_queue_service = DownloadQueueService(event_bus=events)
|
||||||
|
model_metadata_service = ModelMetadataStoreSQL(db=db)
|
||||||
model_manager = ModelManagerService.build_model_manager(
|
model_manager = ModelManagerService.build_model_manager(
|
||||||
app_config=configuration, db=db, download_queue=download_queue_service, events=events
|
app_config=configuration,
|
||||||
|
model_record_service=ModelRecordServiceSQL(db=db, metadata_store=model_metadata_service),
|
||||||
|
download_queue=download_queue_service,
|
||||||
|
events=events,
|
||||||
)
|
)
|
||||||
names = SimpleNameService()
|
names = SimpleNameService()
|
||||||
performance_statistics = InvocationStatsService()
|
performance_statistics = InvocationStatsService()
|
||||||
|
@ -194,7 +194,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
|||||||
"""Block until the indicated job has reached terminal state, or when timeout limit reached."""
|
"""Block until the indicated job has reached terminal state, or when timeout limit reached."""
|
||||||
start = time.time()
|
start = time.time()
|
||||||
while not job.in_terminal_state:
|
while not job.in_terminal_state:
|
||||||
if self._job_completed_event.wait(timeout=5): # in case we miss an event
|
if self._job_completed_event.wait(timeout=0.25): # in case we miss an event
|
||||||
self._job_completed_event.clear()
|
self._job_completed_event.clear()
|
||||||
if timeout > 0 and time.time() - start > timeout:
|
if timeout > 0 and time.time() - start > timeout:
|
||||||
raise TimeoutError("Timeout exceeded")
|
raise TimeoutError("Timeout exceeded")
|
||||||
|
@ -46,8 +46,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
|||||||
# This is to handle case of the model manager not being initialized, which happens
|
# This is to handle case of the model manager not being initialized, which happens
|
||||||
# during some tests.
|
# during some tests.
|
||||||
services = self._invoker.services
|
services = self._invoker.services
|
||||||
if services.model_manager is None or services.model_manager.load is None:
|
|
||||||
yield None
|
|
||||||
if not self._stats.get(graph_execution_state_id):
|
if not self._stats.get(graph_execution_state_id):
|
||||||
# First time we're seeing this graph_execution_state_id.
|
# First time we're seeing this graph_execution_state_id.
|
||||||
self._stats[graph_execution_state_id] = GraphExecutionStats()
|
self._stats[graph_execution_state_id] = GraphExecutionStats()
|
||||||
|
@ -18,7 +18,9 @@ from invokeai.app.services.events import EventServiceBase
|
|||||||
from invokeai.app.services.invoker import Invoker
|
from invokeai.app.services.invoker import Invoker
|
||||||
from invokeai.app.services.model_records import ModelRecordServiceBase
|
from invokeai.app.services.model_records import ModelRecordServiceBase
|
||||||
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
|
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
|
||||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore
|
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||||
|
|
||||||
|
from ..model_metadata import ModelMetadataStoreBase
|
||||||
|
|
||||||
|
|
||||||
class InstallStatus(str, Enum):
|
class InstallStatus(str, Enum):
|
||||||
@ -243,7 +245,7 @@ class ModelInstallServiceBase(ABC):
|
|||||||
app_config: InvokeAIAppConfig,
|
app_config: InvokeAIAppConfig,
|
||||||
record_store: ModelRecordServiceBase,
|
record_store: ModelRecordServiceBase,
|
||||||
download_queue: DownloadQueueServiceBase,
|
download_queue: DownloadQueueServiceBase,
|
||||||
metadata_store: ModelMetadataStore,
|
metadata_store: ModelMetadataStoreBase,
|
||||||
event_bus: Optional["EventServiceBase"] = None,
|
event_bus: Optional["EventServiceBase"] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
@ -20,7 +20,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
|
|||||||
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase, TqdmProgress
|
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase, TqdmProgress
|
||||||
from invokeai.app.services.events.events_base import EventServiceBase
|
from invokeai.app.services.events.events_base import EventServiceBase
|
||||||
from invokeai.app.services.invoker import Invoker
|
from invokeai.app.services.invoker import Invoker
|
||||||
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase, ModelRecordServiceSQL
|
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
|
||||||
from invokeai.backend.model_manager.config import (
|
from invokeai.backend.model_manager.config import (
|
||||||
AnyModelConfig,
|
AnyModelConfig,
|
||||||
BaseModelType,
|
BaseModelType,
|
||||||
@ -33,7 +33,6 @@ from invokeai.backend.model_manager.metadata import (
|
|||||||
AnyModelRepoMetadata,
|
AnyModelRepoMetadata,
|
||||||
CivitaiMetadataFetch,
|
CivitaiMetadataFetch,
|
||||||
HuggingFaceMetadataFetch,
|
HuggingFaceMetadataFetch,
|
||||||
ModelMetadataStore,
|
|
||||||
ModelMetadataWithFiles,
|
ModelMetadataWithFiles,
|
||||||
RemoteModelFile,
|
RemoteModelFile,
|
||||||
)
|
)
|
||||||
@ -65,7 +64,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
app_config: InvokeAIAppConfig,
|
app_config: InvokeAIAppConfig,
|
||||||
record_store: ModelRecordServiceBase,
|
record_store: ModelRecordServiceBase,
|
||||||
download_queue: DownloadQueueServiceBase,
|
download_queue: DownloadQueueServiceBase,
|
||||||
metadata_store: Optional[ModelMetadataStore] = None,
|
|
||||||
event_bus: Optional[EventServiceBase] = None,
|
event_bus: Optional[EventServiceBase] = None,
|
||||||
session: Optional[Session] = None,
|
session: Optional[Session] = None,
|
||||||
):
|
):
|
||||||
@ -93,14 +91,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
self._running = False
|
self._running = False
|
||||||
self._session = session
|
self._session = session
|
||||||
self._next_job_id = 0
|
self._next_job_id = 0
|
||||||
# There may not necessarily be a metadata store initialized
|
self._metadata_store = record_store.metadata_store # for convenience
|
||||||
# so we create one and initialize it with the same sql database
|
|
||||||
# used by the record store service.
|
|
||||||
if metadata_store:
|
|
||||||
self._metadata_store = metadata_store
|
|
||||||
else:
|
|
||||||
assert isinstance(record_store, ModelRecordServiceSQL)
|
|
||||||
self._metadata_store = ModelMetadataStore(record_store.db)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def app_config(self) -> InvokeAIAppConfig: # noqa D102
|
def app_config(self) -> InvokeAIAppConfig: # noqa D102
|
||||||
@ -259,7 +250,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
"""Block until all installation jobs are done."""
|
"""Block until all installation jobs are done."""
|
||||||
start = time.time()
|
start = time.time()
|
||||||
while len(self._download_cache) > 0:
|
while len(self._download_cache) > 0:
|
||||||
if self._downloads_changed_event.wait(timeout=5): # in case we miss an event
|
if self._downloads_changed_event.wait(timeout=0.25): # in case we miss an event
|
||||||
self._downloads_changed_event.clear()
|
self._downloads_changed_event.clear()
|
||||||
if timeout > 0 and time.time() - start > timeout:
|
if timeout > 0 and time.time() - start > timeout:
|
||||||
raise TimeoutError("Timeout exceeded")
|
raise TimeoutError("Timeout exceeded")
|
||||||
|
@ -5,7 +5,6 @@ from typing_extensions import Self
|
|||||||
|
|
||||||
from invokeai.app.services.invoker import Invoker
|
from invokeai.app.services.invoker import Invoker
|
||||||
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache
|
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache
|
||||||
from invokeai.backend.model_manager.metadata import ModelMetadataStore
|
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
|
|
||||||
from ..config import InvokeAIAppConfig
|
from ..config import InvokeAIAppConfig
|
||||||
@ -13,8 +12,7 @@ from ..download import DownloadQueueServiceBase
|
|||||||
from ..events.events_base import EventServiceBase
|
from ..events.events_base import EventServiceBase
|
||||||
from ..model_install import ModelInstallService, ModelInstallServiceBase
|
from ..model_install import ModelInstallService, ModelInstallServiceBase
|
||||||
from ..model_load import ModelLoadService, ModelLoadServiceBase
|
from ..model_load import ModelLoadService, ModelLoadServiceBase
|
||||||
from ..model_records import ModelRecordServiceBase, ModelRecordServiceSQL
|
from ..model_records import ModelRecordServiceBase
|
||||||
from ..shared.sqlite.sqlite_database import SqliteDatabase
|
|
||||||
from .model_manager_base import ModelManagerServiceBase
|
from .model_manager_base import ModelManagerServiceBase
|
||||||
|
|
||||||
|
|
||||||
@ -64,7 +62,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
def build_model_manager(
|
def build_model_manager(
|
||||||
cls,
|
cls,
|
||||||
app_config: InvokeAIAppConfig,
|
app_config: InvokeAIAppConfig,
|
||||||
db: SqliteDatabase,
|
model_record_service: ModelRecordServiceBase,
|
||||||
download_queue: DownloadQueueServiceBase,
|
download_queue: DownloadQueueServiceBase,
|
||||||
events: EventServiceBase,
|
events: EventServiceBase,
|
||||||
) -> Self:
|
) -> Self:
|
||||||
@ -82,19 +80,16 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
convert_cache = ModelConvertCache(
|
convert_cache = ModelConvertCache(
|
||||||
cache_path=app_config.models_convert_cache_path, max_size=app_config.convert_cache_size
|
cache_path=app_config.models_convert_cache_path, max_size=app_config.convert_cache_size
|
||||||
)
|
)
|
||||||
record_store = ModelRecordServiceSQL(db=db)
|
|
||||||
loader = ModelLoadService(
|
loader = ModelLoadService(
|
||||||
app_config=app_config,
|
app_config=app_config,
|
||||||
record_store=record_store,
|
record_store=model_record_service,
|
||||||
ram_cache=ram_cache,
|
ram_cache=ram_cache,
|
||||||
convert_cache=convert_cache,
|
convert_cache=convert_cache,
|
||||||
)
|
)
|
||||||
record_store._loader = loader # yeah, there is a circular reference here
|
|
||||||
installer = ModelInstallService(
|
installer = ModelInstallService(
|
||||||
app_config=app_config,
|
app_config=app_config,
|
||||||
record_store=record_store,
|
record_store=model_record_service,
|
||||||
download_queue=download_queue,
|
download_queue=download_queue,
|
||||||
metadata_store=ModelMetadataStore(db=db),
|
|
||||||
event_bus=events,
|
event_bus=events,
|
||||||
)
|
)
|
||||||
return cls(store=record_store, install=installer, load=loader)
|
return cls(store=model_record_service, install=installer, load=loader)
|
||||||
|
9
invokeai/app/services/model_metadata/__init__.py
Normal file
9
invokeai/app/services/model_metadata/__init__.py
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
"""Init file for ModelMetadataStoreService module."""
|
||||||
|
|
||||||
|
from .metadata_store_base import ModelMetadataStoreBase
|
||||||
|
from .metadata_store_sql import ModelMetadataStoreSQL
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ModelMetadataStoreBase",
|
||||||
|
"ModelMetadataStoreSQL",
|
||||||
|
]
|
65
invokeai/app/services/model_metadata/metadata_store_base.py
Normal file
65
invokeai/app/services/model_metadata/metadata_store_base.py
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||||
|
"""
|
||||||
|
Storage for Model Metadata
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import List, Set, Tuple
|
||||||
|
|
||||||
|
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||||
|
|
||||||
|
|
||||||
|
class ModelMetadataStoreBase(ABC):
|
||||||
|
"""Store, search and fetch model metadata retrieved from remote repositories."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
|
||||||
|
"""
|
||||||
|
Add a block of repo metadata to a model record.
|
||||||
|
|
||||||
|
The model record config must already exist in the database with the
|
||||||
|
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
|
||||||
|
|
||||||
|
:param model_key: Existing model key in the `model_config` table
|
||||||
|
:param metadata: ModelRepoMetadata object to store
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
|
||||||
|
"""Retrieve the ModelRepoMetadata corresponding to model key."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata
|
||||||
|
"""Dump out all the metadata."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
|
||||||
|
"""
|
||||||
|
Update metadata corresponding to the model with the indicated key.
|
||||||
|
|
||||||
|
:param model_key: Existing model key in the `model_config` table
|
||||||
|
:param metadata: ModelRepoMetadata object to update
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def list_tags(self) -> Set[str]:
|
||||||
|
"""Return all tags in the tags table."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def search_by_tag(self, tags: Set[str]) -> Set[str]:
|
||||||
|
"""Return the keys of models containing all of the listed tags."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def search_by_author(self, author: str) -> Set[str]:
|
||||||
|
"""Return the keys of models authored by the indicated author."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def search_by_name(self, name: str) -> Set[str]:
|
||||||
|
"""
|
||||||
|
Return the keys of models with the indicated name.
|
||||||
|
|
||||||
|
Note that this is the name of the model given to it by
|
||||||
|
the remote source. The user may have changed the local
|
||||||
|
name. The local name will be located in the model config
|
||||||
|
record object.
|
||||||
|
"""
|
222
invokeai/app/services/model_metadata/metadata_store_sql.py
Normal file
222
invokeai/app/services/model_metadata/metadata_store_sql.py
Normal file
@ -0,0 +1,222 @@
|
|||||||
|
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||||
|
"""
|
||||||
|
SQL Storage for Model Metadata
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sqlite3
|
||||||
|
from typing import List, Optional, Set, Tuple
|
||||||
|
|
||||||
|
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||||
|
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException
|
||||||
|
from invokeai.backend.model_manager.metadata.fetch import ModelMetadataFetchBase
|
||||||
|
|
||||||
|
from .metadata_store_base import ModelMetadataStoreBase
|
||||||
|
|
||||||
|
|
||||||
|
class ModelMetadataStoreSQL(ModelMetadataStoreBase):
|
||||||
|
"""Store, search and fetch model metadata retrieved from remote repositories."""
|
||||||
|
|
||||||
|
def __init__(self, db: SqliteDatabase):
|
||||||
|
"""
|
||||||
|
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
|
||||||
|
|
||||||
|
:param conn: sqlite3 connection object
|
||||||
|
:param lock: threading Lock object
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self._db = db
|
||||||
|
self._cursor = self._db.conn.cursor()
|
||||||
|
|
||||||
|
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
|
||||||
|
"""
|
||||||
|
Add a block of repo metadata to a model record.
|
||||||
|
|
||||||
|
The model record config must already exist in the database with the
|
||||||
|
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
|
||||||
|
|
||||||
|
:param model_key: Existing model key in the `model_config` table
|
||||||
|
:param metadata: ModelRepoMetadata object to store
|
||||||
|
"""
|
||||||
|
json_serialized = metadata.model_dump_json()
|
||||||
|
with self._db.lock:
|
||||||
|
try:
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
INSERT INTO model_metadata(
|
||||||
|
id,
|
||||||
|
metadata
|
||||||
|
)
|
||||||
|
VALUES (?,?);
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
model_key,
|
||||||
|
json_serialized,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self._update_tags(model_key, metadata.tags)
|
||||||
|
self._db.conn.commit()
|
||||||
|
except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table
|
||||||
|
self._db.conn.rollback()
|
||||||
|
raise UnknownMetadataException from excp
|
||||||
|
except sqlite3.Error as excp:
|
||||||
|
self._db.conn.rollback()
|
||||||
|
raise excp
|
||||||
|
|
||||||
|
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
|
||||||
|
"""Retrieve the ModelRepoMetadata corresponding to model key."""
|
||||||
|
with self._db.lock:
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT metadata FROM model_metadata
|
||||||
|
WHERE id=?;
|
||||||
|
""",
|
||||||
|
(model_key,),
|
||||||
|
)
|
||||||
|
rows = self._cursor.fetchone()
|
||||||
|
if not rows:
|
||||||
|
raise UnknownMetadataException("model metadata not found")
|
||||||
|
return ModelMetadataFetchBase.from_json(rows[0])
|
||||||
|
|
||||||
|
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata
|
||||||
|
"""Dump out all the metadata."""
|
||||||
|
with self._db.lock:
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT id,metadata FROM model_metadata;
|
||||||
|
""",
|
||||||
|
(),
|
||||||
|
)
|
||||||
|
rows = self._cursor.fetchall()
|
||||||
|
return [(x[0], ModelMetadataFetchBase.from_json(x[1])) for x in rows]
|
||||||
|
|
||||||
|
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
|
||||||
|
"""
|
||||||
|
Update metadata corresponding to the model with the indicated key.
|
||||||
|
|
||||||
|
:param model_key: Existing model key in the `model_config` table
|
||||||
|
:param metadata: ModelRepoMetadata object to update
|
||||||
|
"""
|
||||||
|
json_serialized = metadata.model_dump_json() # turn it into a json string.
|
||||||
|
with self._db.lock:
|
||||||
|
try:
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
UPDATE model_metadata
|
||||||
|
SET
|
||||||
|
metadata=?
|
||||||
|
WHERE id=?;
|
||||||
|
""",
|
||||||
|
(json_serialized, model_key),
|
||||||
|
)
|
||||||
|
if self._cursor.rowcount == 0:
|
||||||
|
raise UnknownMetadataException("model metadata not found")
|
||||||
|
self._update_tags(model_key, metadata.tags)
|
||||||
|
self._db.conn.commit()
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self._db.conn.rollback()
|
||||||
|
raise e
|
||||||
|
|
||||||
|
return self.get_metadata(model_key)
|
||||||
|
|
||||||
|
def list_tags(self) -> Set[str]:
|
||||||
|
"""Return all tags in the tags table."""
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
select tag_text from tags;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
return {x[0] for x in self._cursor.fetchall()}
|
||||||
|
|
||||||
|
def search_by_tag(self, tags: Set[str]) -> Set[str]:
|
||||||
|
"""Return the keys of models containing all of the listed tags."""
|
||||||
|
with self._db.lock:
|
||||||
|
try:
|
||||||
|
matches: Optional[Set[str]] = None
|
||||||
|
for tag in tags:
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT a.model_id FROM model_tags AS a,
|
||||||
|
tags AS b
|
||||||
|
WHERE a.tag_id=b.tag_id
|
||||||
|
AND b.tag_text=?;
|
||||||
|
""",
|
||||||
|
(tag,),
|
||||||
|
)
|
||||||
|
model_keys = {x[0] for x in self._cursor.fetchall()}
|
||||||
|
if matches is None:
|
||||||
|
matches = model_keys
|
||||||
|
matches = matches.intersection(model_keys)
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
raise e
|
||||||
|
return matches if matches else set()
|
||||||
|
|
||||||
|
def search_by_author(self, author: str) -> Set[str]:
|
||||||
|
"""Return the keys of models authored by the indicated author."""
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT id FROM model_metadata
|
||||||
|
WHERE author=?;
|
||||||
|
""",
|
||||||
|
(author,),
|
||||||
|
)
|
||||||
|
return {x[0] for x in self._cursor.fetchall()}
|
||||||
|
|
||||||
|
def search_by_name(self, name: str) -> Set[str]:
|
||||||
|
"""
|
||||||
|
Return the keys of models with the indicated name.
|
||||||
|
|
||||||
|
Note that this is the name of the model given to it by
|
||||||
|
the remote source. The user may have changed the local
|
||||||
|
name. The local name will be located in the model config
|
||||||
|
record object.
|
||||||
|
"""
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT id FROM model_metadata
|
||||||
|
WHERE name=?;
|
||||||
|
""",
|
||||||
|
(name,),
|
||||||
|
)
|
||||||
|
return {x[0] for x in self._cursor.fetchall()}
|
||||||
|
|
||||||
|
def _update_tags(self, model_key: str, tags: Set[str]) -> None:
|
||||||
|
"""Update tags for the model referenced by model_key."""
|
||||||
|
# remove previous tags from this model
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
DELETE FROM model_tags
|
||||||
|
WHERE model_id=?;
|
||||||
|
""",
|
||||||
|
(model_key,),
|
||||||
|
)
|
||||||
|
|
||||||
|
for tag in tags:
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
INSERT OR IGNORE INTO tags (
|
||||||
|
tag_text
|
||||||
|
)
|
||||||
|
VALUES (?);
|
||||||
|
""",
|
||||||
|
(tag,),
|
||||||
|
)
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT tag_id
|
||||||
|
FROM tags
|
||||||
|
WHERE tag_text = ?
|
||||||
|
LIMIT 1;
|
||||||
|
""",
|
||||||
|
(tag,),
|
||||||
|
)
|
||||||
|
tag_id = self._cursor.fetchone()[0]
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
INSERT OR IGNORE INTO model_tags (
|
||||||
|
model_id,
|
||||||
|
tag_id
|
||||||
|
)
|
||||||
|
VALUES (?,?);
|
||||||
|
""",
|
||||||
|
(model_key, tag_id),
|
||||||
|
)
|
@ -17,7 +17,9 @@ from invokeai.backend.model_manager import (
|
|||||||
ModelFormat,
|
ModelFormat,
|
||||||
ModelType,
|
ModelType,
|
||||||
)
|
)
|
||||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore
|
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||||
|
|
||||||
|
from ..model_metadata import ModelMetadataStoreBase
|
||||||
|
|
||||||
|
|
||||||
class DuplicateModelException(Exception):
|
class DuplicateModelException(Exception):
|
||||||
@ -109,7 +111,7 @@ class ModelRecordServiceBase(ABC):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def metadata_store(self) -> ModelMetadataStore:
|
def metadata_store(self) -> ModelMetadataStoreBase:
|
||||||
"""Return a ModelMetadataStore initialized on the same database."""
|
"""Return a ModelMetadataStore initialized on the same database."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -54,8 +54,9 @@ from invokeai.backend.model_manager.config import (
|
|||||||
ModelFormat,
|
ModelFormat,
|
||||||
ModelType,
|
ModelType,
|
||||||
)
|
)
|
||||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore, UnknownMetadataException
|
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException
|
||||||
|
|
||||||
|
from ..model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL
|
||||||
from ..shared.sqlite.sqlite_database import SqliteDatabase
|
from ..shared.sqlite.sqlite_database import SqliteDatabase
|
||||||
from .model_records_base import (
|
from .model_records_base import (
|
||||||
DuplicateModelException,
|
DuplicateModelException,
|
||||||
@ -69,7 +70,7 @@ from .model_records_base import (
|
|||||||
class ModelRecordServiceSQL(ModelRecordServiceBase):
|
class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||||
"""Implementation of the ModelConfigStore ABC using a SQL database."""
|
"""Implementation of the ModelConfigStore ABC using a SQL database."""
|
||||||
|
|
||||||
def __init__(self, db: SqliteDatabase):
|
def __init__(self, db: SqliteDatabase, metadata_store: ModelMetadataStoreBase):
|
||||||
"""
|
"""
|
||||||
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
|
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
|
||||||
|
|
||||||
@ -78,6 +79,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self._db = db
|
self._db = db
|
||||||
self._cursor = db.conn.cursor()
|
self._cursor = db.conn.cursor()
|
||||||
|
self._metadata_store = metadata_store
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def db(self) -> SqliteDatabase:
|
def db(self) -> SqliteDatabase:
|
||||||
@ -157,7 +159,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
self._db.conn.rollback()
|
self._db.conn.rollback()
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def update_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
|
def update_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
|
||||||
"""
|
"""
|
||||||
Update the model, returning the updated version.
|
Update the model, returning the updated version.
|
||||||
|
|
||||||
@ -307,9 +309,9 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
return results
|
return results
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def metadata_store(self) -> ModelMetadataStore:
|
def metadata_store(self) -> ModelMetadataStoreBase:
|
||||||
"""Return a ModelMetadataStore initialized on the same database."""
|
"""Return a ModelMetadataStore initialized on the same database."""
|
||||||
return ModelMetadataStore(self._db)
|
return self._metadata_store
|
||||||
|
|
||||||
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
|
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
|
||||||
"""
|
"""
|
||||||
@ -330,18 +332,18 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
|
|
||||||
:param tags: Set of tags to search for. All tags must be present.
|
:param tags: Set of tags to search for. All tags must be present.
|
||||||
"""
|
"""
|
||||||
store = ModelMetadataStore(self._db)
|
store = ModelMetadataStoreSQL(self._db)
|
||||||
keys = store.search_by_tag(tags)
|
keys = store.search_by_tag(tags)
|
||||||
return [self.get_model(x) for x in keys]
|
return [self.get_model(x) for x in keys]
|
||||||
|
|
||||||
def list_tags(self) -> Set[str]:
|
def list_tags(self) -> Set[str]:
|
||||||
"""Return a unique set of all the model tags in the metadata database."""
|
"""Return a unique set of all the model tags in the metadata database."""
|
||||||
store = ModelMetadataStore(self._db)
|
store = ModelMetadataStoreSQL(self._db)
|
||||||
return store.list_tags()
|
return store.list_tags()
|
||||||
|
|
||||||
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]:
|
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]:
|
||||||
"""List metadata for all models that have it."""
|
"""List metadata for all models that have it."""
|
||||||
store = ModelMetadataStore(self._db)
|
store = ModelMetadataStoreSQL(self._db)
|
||||||
return store.list_all_metadata()
|
return store.list_all_metadata()
|
||||||
|
|
||||||
def list_models(
|
def list_models(
|
||||||
|
@ -25,6 +25,7 @@ from invokeai.app.services.model_install import (
|
|||||||
ModelSource,
|
ModelSource,
|
||||||
URLModelSource,
|
URLModelSource,
|
||||||
)
|
)
|
||||||
|
from invokeai.app.services.model_metadata import ModelMetadataStoreSQL
|
||||||
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
|
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
|
||||||
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
||||||
from invokeai.backend.model_manager import (
|
from invokeai.backend.model_manager import (
|
||||||
@ -45,7 +46,7 @@ def initialize_record_store(app_config: InvokeAIAppConfig) -> ModelRecordService
|
|||||||
logger = InvokeAILogger.get_logger(config=app_config)
|
logger = InvokeAILogger.get_logger(config=app_config)
|
||||||
image_files = DiskImageFileStorage(f"{app_config.output_path}/images")
|
image_files = DiskImageFileStorage(f"{app_config.output_path}/images")
|
||||||
db = init_db(config=app_config, logger=logger, image_files=image_files)
|
db = init_db(config=app_config, logger=logger, image_files=image_files)
|
||||||
obj: ModelRecordServiceBase = ModelRecordServiceSQL(db)
|
obj: ModelRecordServiceBase = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
|
||||||
@ -54,12 +55,10 @@ def initialize_installer(
|
|||||||
) -> ModelInstallServiceBase:
|
) -> ModelInstallServiceBase:
|
||||||
"""Return an initialized ModelInstallService object."""
|
"""Return an initialized ModelInstallService object."""
|
||||||
record_store = initialize_record_store(app_config)
|
record_store = initialize_record_store(app_config)
|
||||||
metadata_store = record_store.metadata_store
|
|
||||||
download_queue = DownloadQueueService()
|
download_queue = DownloadQueueService()
|
||||||
installer = ModelInstallService(
|
installer = ModelInstallService(
|
||||||
app_config=app_config,
|
app_config=app_config,
|
||||||
record_store=record_store,
|
record_store=record_store,
|
||||||
metadata_store=metadata_store,
|
|
||||||
download_queue=download_queue,
|
download_queue=download_queue,
|
||||||
event_bus=event_bus,
|
event_bus=event_bus,
|
||||||
)
|
)
|
||||||
@ -287,14 +286,14 @@ class InstallHelper(object):
|
|||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
)
|
)
|
||||||
if len(matches) > 1:
|
if len(matches) > 1:
|
||||||
print(
|
self._logger.error(
|
||||||
f"{model_to_remove} is ambiguous. Please use model_base/model_type/model_name (e.g. sd-1/main/my_model) to disambiguate."
|
"{model_to_remove} is ambiguous. Please use model_base/model_type/model_name (e.g. sd-1/main/my_model) to disambiguate"
|
||||||
)
|
)
|
||||||
elif not matches:
|
elif not matches:
|
||||||
print(f"{model_to_remove}: unknown model")
|
self._logger.error(f"{model_to_remove}: unknown model")
|
||||||
else:
|
else:
|
||||||
for m in matches:
|
for m in matches:
|
||||||
print(f"Deleting {m.type}:{m.name}")
|
self._logger.info(f"Deleting {m.type}:{m.name}")
|
||||||
installer.delete(m.key)
|
installer.delete(m.key)
|
||||||
|
|
||||||
installer.wait_for_installs()
|
installer.wait_for_installs()
|
||||||
|
@ -4,10 +4,6 @@ Init file for the model loader.
|
|||||||
"""
|
"""
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
|
||||||
|
|
||||||
from .convert_cache.convert_cache_default import ModelConvertCache
|
from .convert_cache.convert_cache_default import ModelConvertCache
|
||||||
from .load_base import AnyModelLoader, LoadedModel
|
from .load_base import AnyModelLoader, LoadedModel
|
||||||
@ -19,16 +15,3 @@ for module in loaders:
|
|||||||
import_module(f"{__package__}.model_loaders.{module}")
|
import_module(f"{__package__}.model_loaders.{module}")
|
||||||
|
|
||||||
__all__ = ["AnyModelLoader", "LoadedModel", "ModelCache", "ModelConvertCache"]
|
__all__ = ["AnyModelLoader", "LoadedModel", "ModelCache", "ModelConvertCache"]
|
||||||
|
|
||||||
|
|
||||||
def get_standalone_loader(app_config: Optional[InvokeAIAppConfig]) -> AnyModelLoader:
|
|
||||||
app_config = app_config or InvokeAIAppConfig.get_config()
|
|
||||||
logger = InvokeAILogger.get_logger(config=app_config)
|
|
||||||
return AnyModelLoader(
|
|
||||||
app_config=app_config,
|
|
||||||
logger=logger,
|
|
||||||
ram_cache=ModelCache(
|
|
||||||
logger=logger, max_cache_size=app_config.ram_cache_size, max_vram_cache_size=app_config.vram_cache_size
|
|
||||||
),
|
|
||||||
convert_cache=ModelConvertCache(app_config.models_convert_cache_path),
|
|
||||||
)
|
|
||||||
|
@ -39,21 +39,21 @@ class LoadedModel:
|
|||||||
"""Context manager object that mediates transfer from RAM<->VRAM."""
|
"""Context manager object that mediates transfer from RAM<->VRAM."""
|
||||||
|
|
||||||
config: AnyModelConfig
|
config: AnyModelConfig
|
||||||
locker: ModelLockerBase
|
_locker: ModelLockerBase
|
||||||
|
|
||||||
def __enter__(self) -> AnyModel:
|
def __enter__(self) -> AnyModel:
|
||||||
"""Context entry."""
|
"""Context entry."""
|
||||||
self.locker.lock()
|
self._locker.lock()
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
def __exit__(self, *args: Any, **kwargs: Any) -> None:
|
def __exit__(self, *args: Any, **kwargs: Any) -> None:
|
||||||
"""Context exit."""
|
"""Context exit."""
|
||||||
self.locker.unlock()
|
self._locker.unlock()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def model(self) -> AnyModel:
|
def model(self) -> AnyModel:
|
||||||
"""Return the model without locking it."""
|
"""Return the model without locking it."""
|
||||||
return self.locker.model
|
return self._locker.model
|
||||||
|
|
||||||
|
|
||||||
class ModelLoaderBase(ABC):
|
class ModelLoaderBase(ABC):
|
||||||
|
@ -75,7 +75,7 @@ class ModelLoader(ModelLoaderBase):
|
|||||||
|
|
||||||
model_path = self._convert_if_needed(model_config, model_path, submodel_type)
|
model_path = self._convert_if_needed(model_config, model_path, submodel_type)
|
||||||
locker = self._load_if_needed(model_config, model_path, submodel_type)
|
locker = self._load_if_needed(model_config, model_path, submodel_type)
|
||||||
return LoadedModel(config=model_config, locker=locker)
|
return LoadedModel(config=model_config, _locker=locker)
|
||||||
|
|
||||||
def _get_model_path(
|
def _get_model_path(
|
||||||
self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None
|
self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None
|
||||||
|
@ -39,10 +39,7 @@ class ModelMerger(object):
|
|||||||
|
|
||||||
def __init__(self, installer: ModelInstallServiceBase):
|
def __init__(self, installer: ModelInstallServiceBase):
|
||||||
"""
|
"""
|
||||||
Initialize a ModelMerger object.
|
Initialize a ModelMerger object with the model installer.
|
||||||
|
|
||||||
:param store: Underlying storage manager for the running process.
|
|
||||||
:param config: InvokeAIAppConfig object (if not provided, default will be selected).
|
|
||||||
"""
|
"""
|
||||||
self._installer = installer
|
self._installer = installer
|
||||||
|
|
||||||
|
@ -18,7 +18,7 @@ assert isinstance(data, CivitaiMetadata)
|
|||||||
if data.allow_commercial_use:
|
if data.allow_commercial_use:
|
||||||
print("Commercial use of this model is allowed")
|
print("Commercial use of this model is allowed")
|
||||||
"""
|
"""
|
||||||
from .fetch import CivitaiMetadataFetch, HuggingFaceMetadataFetch
|
from .fetch import CivitaiMetadataFetch, HuggingFaceMetadataFetch, ModelMetadataFetchBase
|
||||||
from .metadata_base import (
|
from .metadata_base import (
|
||||||
AnyModelRepoMetadata,
|
AnyModelRepoMetadata,
|
||||||
AnyModelRepoMetadataValidator,
|
AnyModelRepoMetadataValidator,
|
||||||
@ -31,7 +31,6 @@ from .metadata_base import (
|
|||||||
RemoteModelFile,
|
RemoteModelFile,
|
||||||
UnknownMetadataException,
|
UnknownMetadataException,
|
||||||
)
|
)
|
||||||
from .metadata_store import ModelMetadataStore
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AnyModelRepoMetadata",
|
"AnyModelRepoMetadata",
|
||||||
@ -42,7 +41,7 @@ __all__ = [
|
|||||||
"HuggingFaceMetadata",
|
"HuggingFaceMetadata",
|
||||||
"HuggingFaceMetadataFetch",
|
"HuggingFaceMetadataFetch",
|
||||||
"LicenseRestrictions",
|
"LicenseRestrictions",
|
||||||
"ModelMetadataStore",
|
"ModelMetadataFetchBase",
|
||||||
"BaseMetadata",
|
"BaseMetadata",
|
||||||
"ModelMetadataWithFiles",
|
"ModelMetadataWithFiles",
|
||||||
"RemoteModelFile",
|
"RemoteModelFile",
|
||||||
|
@ -6,20 +6,40 @@ Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
|
|||||||
"""
|
"""
|
||||||
import argparse
|
import argparse
|
||||||
import curses
|
import curses
|
||||||
|
import re
|
||||||
import sys
|
import sys
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import npyscreen
|
import npyscreen
|
||||||
from npyscreen import widget
|
from npyscreen import widget
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.backend.model_management import BaseModelType, ModelManager, ModelMerger, ModelType
|
from invokeai.app.services.download import DownloadQueueService
|
||||||
|
from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage
|
||||||
|
from invokeai.app.services.model_install import ModelInstallService
|
||||||
|
from invokeai.app.services.model_metadata import ModelMetadataStoreSQL
|
||||||
|
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
|
||||||
|
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
||||||
|
from invokeai.backend.model_manager import (
|
||||||
|
BaseModelType,
|
||||||
|
ModelFormat,
|
||||||
|
ModelType,
|
||||||
|
ModelVariantType,
|
||||||
|
)
|
||||||
|
from invokeai.backend.model_manager.merge import ModelMerger
|
||||||
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
from invokeai.frontend.install.widgets import FloatTitleSlider, SingleSelectColumns, TextBox
|
from invokeai.frontend.install.widgets import FloatTitleSlider, SingleSelectColumns, TextBox
|
||||||
|
|
||||||
config = InvokeAIAppConfig.get_config()
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
logger = InvokeAILogger.get_logger()
|
||||||
|
|
||||||
|
BASE_TYPES = [
|
||||||
|
(BaseModelType.StableDiffusion1, "Models Built on SD-1.x"),
|
||||||
|
(BaseModelType.StableDiffusion2, "Models Built on SD-2.x"),
|
||||||
|
(BaseModelType.StableDiffusionXL, "Models Built on SDXL"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def _parse_args() -> Namespace:
|
def _parse_args() -> Namespace:
|
||||||
@ -48,7 +68,7 @@ def _parse_args() -> Namespace:
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--base_model",
|
"--base_model",
|
||||||
type=str,
|
type=str,
|
||||||
choices=[x.value for x in BaseModelType],
|
choices=[x[0].value for x in BASE_TYPES],
|
||||||
help="The base model shared by the models to be merged",
|
help="The base model shared by the models to be merged",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -98,17 +118,17 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
|||||||
super().__init__(parentApp, name)
|
super().__init__(parentApp, name)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def model_manager(self):
|
def record_store(self):
|
||||||
return self.parentApp.model_manager
|
return self.parentApp.record_store
|
||||||
|
|
||||||
def afterEditing(self):
|
def afterEditing(self):
|
||||||
self.parentApp.setNextForm(None)
|
self.parentApp.setNextForm(None)
|
||||||
|
|
||||||
def create(self):
|
def create(self):
|
||||||
window_height, window_width = curses.initscr().getmaxyx()
|
window_height, window_width = curses.initscr().getmaxyx()
|
||||||
|
|
||||||
self.model_names = self.get_model_names()
|
|
||||||
self.current_base = 0
|
self.current_base = 0
|
||||||
|
self.models = self.get_models(BASE_TYPES[self.current_base][0])
|
||||||
|
self.model_names = [x[1] for x in self.models]
|
||||||
max_width = max([len(x) for x in self.model_names])
|
max_width = max([len(x) for x in self.model_names])
|
||||||
max_width += 6
|
max_width += 6
|
||||||
horizontal_layout = max_width * 3 < window_width
|
horizontal_layout = max_width * 3 < window_width
|
||||||
@ -128,11 +148,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
|||||||
self.nextrely += 1
|
self.nextrely += 1
|
||||||
self.base_select = self.add_widget_intelligent(
|
self.base_select = self.add_widget_intelligent(
|
||||||
SingleSelectColumns,
|
SingleSelectColumns,
|
||||||
values=[
|
values=[x[1] for x in BASE_TYPES],
|
||||||
"Models Built on SD-1.x",
|
|
||||||
"Models Built on SD-2.x",
|
|
||||||
"Models Built on SDXL",
|
|
||||||
],
|
|
||||||
value=[self.current_base],
|
value=[self.current_base],
|
||||||
columns=4,
|
columns=4,
|
||||||
max_height=2,
|
max_height=2,
|
||||||
@ -263,21 +279,20 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
|||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
def marshall_arguments(self) -> dict:
|
def marshall_arguments(self) -> dict:
|
||||||
model_names = self.model_names
|
model_keys = [x[0] for x in self.models]
|
||||||
models = [
|
models = [
|
||||||
model_names[self.model1.value[0]],
|
model_keys[self.model1.value[0]],
|
||||||
model_names[self.model2.value[0]],
|
model_keys[self.model2.value[0]],
|
||||||
]
|
]
|
||||||
if self.model3.value[0] > 0:
|
if self.model3.value[0] > 0:
|
||||||
models.append(model_names[self.model3.value[0] - 1])
|
models.append(model_keys[self.model3.value[0] - 1])
|
||||||
interp = "add_difference"
|
interp = "add_difference"
|
||||||
else:
|
else:
|
||||||
interp = self.interpolations[self.merge_method.value[0]]
|
interp = self.interpolations[self.merge_method.value[0]]
|
||||||
|
|
||||||
bases = ["sd-1", "sd-2", "sdxl"]
|
|
||||||
args = {
|
args = {
|
||||||
"model_names": models,
|
"model_keys": models,
|
||||||
"base_model": BaseModelType(bases[self.base_select.value[0]]),
|
"base_model": tuple(BaseModelType)[self.base_select.value[0]],
|
||||||
"alpha": self.alpha.value,
|
"alpha": self.alpha.value,
|
||||||
"interp": interp,
|
"interp": interp,
|
||||||
"force": self.force.value,
|
"force": self.force.value,
|
||||||
@ -311,18 +326,18 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
|||||||
else:
|
else:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def get_model_names(self, base_model: BaseModelType = BaseModelType.StableDiffusion1) -> List[str]:
|
def get_models(self, base_model: Optional[BaseModelType] = None) -> List[Tuple[str, str]]: # key to name
|
||||||
model_names = [
|
models = [
|
||||||
info["model_name"]
|
(x.key, x.name)
|
||||||
for info in self.model_manager.list_models(model_type=ModelType.Main, base_model=base_model)
|
for x in self.record_store.search_by_attr(model_type=ModelType.Main, base_model=base_model)
|
||||||
if info["model_format"] == "diffusers"
|
if x.format == ModelFormat("diffusers") and x.variant == ModelVariantType("normal")
|
||||||
]
|
]
|
||||||
return sorted(model_names)
|
return sorted(models, key=lambda x: x[1])
|
||||||
|
|
||||||
def _populate_models(self, value=None):
|
def _populate_models(self, value: List[int]):
|
||||||
bases = ["sd-1", "sd-2", "sdxl"]
|
base_model = BASE_TYPES[value[0]][0]
|
||||||
base_model = BaseModelType(bases[value[0]])
|
self.models = self.get_models(base_model)
|
||||||
self.model_names = self.get_model_names(base_model)
|
self.model_names = [x[1] for x in self.models]
|
||||||
|
|
||||||
models_plus_none = self.model_names.copy()
|
models_plus_none = self.model_names.copy()
|
||||||
models_plus_none.insert(0, "None")
|
models_plus_none.insert(0, "None")
|
||||||
@ -334,24 +349,24 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
|||||||
|
|
||||||
|
|
||||||
class Mergeapp(npyscreen.NPSAppManaged):
|
class Mergeapp(npyscreen.NPSAppManaged):
|
||||||
def __init__(self, model_manager: ModelManager):
|
def __init__(self, record_store: ModelRecordServiceBase):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.model_manager = model_manager
|
self.record_store = record_store
|
||||||
|
|
||||||
def onStart(self):
|
def onStart(self):
|
||||||
npyscreen.setTheme(npyscreen.Themes.ElegantTheme)
|
npyscreen.setTheme(npyscreen.Themes.ElegantTheme)
|
||||||
self.main = self.addForm("MAIN", mergeModelsForm, name="Merge Models Settings")
|
self.main = self.addForm("MAIN", mergeModelsForm, name="Merge Models Settings")
|
||||||
|
|
||||||
|
|
||||||
def run_gui(args: Namespace):
|
def run_gui(args: Namespace) -> None:
|
||||||
model_manager = ModelManager(config.model_conf_path)
|
record_store: ModelRecordServiceBase = get_config_store()
|
||||||
mergeapp = Mergeapp(model_manager)
|
mergeapp = Mergeapp(record_store)
|
||||||
mergeapp.run()
|
mergeapp.run()
|
||||||
|
|
||||||
args = mergeapp.merge_arguments
|
args = mergeapp.merge_arguments
|
||||||
merger = ModelMerger(model_manager)
|
merger = get_model_merger(record_store)
|
||||||
merger.merge_diffusion_models_and_save(**args)
|
merger.merge_diffusion_models_and_save(**args)
|
||||||
logger.info(f'Models merged into new model: "{args["merged_model_name"]}".')
|
merged_model_name = args["merged_model_name"]
|
||||||
|
logger.info(f'Models merged into new model: "{merged_model_name}".')
|
||||||
|
|
||||||
|
|
||||||
def run_cli(args: Namespace):
|
def run_cli(args: Namespace):
|
||||||
@ -364,20 +379,54 @@ def run_cli(args: Namespace):
|
|||||||
args.merged_model_name = "+".join(args.model_names)
|
args.merged_model_name = "+".join(args.model_names)
|
||||||
logger.info(f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"')
|
logger.info(f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"')
|
||||||
|
|
||||||
model_manager = ModelManager(config.model_conf_path)
|
record_store: ModelRecordServiceBase = get_config_store()
|
||||||
assert (
|
assert (
|
||||||
not model_manager.model_exists(args.merged_model_name, args.base_model, ModelType.Main) or args.clobber
|
len(record_store.search_by_attr(args.merged_model_name, args.base_model, ModelType.Main)) == 0 or args.clobber
|
||||||
), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
|
), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
|
||||||
|
|
||||||
merger = ModelMerger(model_manager)
|
merger = get_model_merger(record_store)
|
||||||
merger.merge_diffusion_models_and_save(**vars(args))
|
model_keys = []
|
||||||
|
for name in args.model_names:
|
||||||
|
if len(name) == 32 and re.match(r"^[0-9a-f]$", name):
|
||||||
|
model_keys.append(name)
|
||||||
|
else:
|
||||||
|
models = record_store.search_by_attr(
|
||||||
|
model_name=name, model_type=ModelType.Main, base_model=BaseModelType(args.base_model)
|
||||||
|
)
|
||||||
|
assert len(models) > 0, f"{name}: Unknown model"
|
||||||
|
assert len(models) < 2, f"{name}: More than one model by this name. Please specify the model key instead."
|
||||||
|
model_keys.append(models[0].key)
|
||||||
|
|
||||||
|
merger.merge_diffusion_models_and_save(
|
||||||
|
alpha=args.alpha,
|
||||||
|
model_keys=model_keys,
|
||||||
|
merged_model_name=args.merged_model_name,
|
||||||
|
interp=args.interp,
|
||||||
|
force=args.force,
|
||||||
|
)
|
||||||
logger.info(f'Models merged into new model: "{args.merged_model_name}".')
|
logger.info(f'Models merged into new model: "{args.merged_model_name}".')
|
||||||
|
|
||||||
|
|
||||||
|
def get_config_store() -> ModelRecordServiceSQL:
|
||||||
|
output_path = config.output_path
|
||||||
|
assert output_path is not None
|
||||||
|
image_files = DiskImageFileStorage(output_path / "images")
|
||||||
|
db = init_db(config=config, logger=InvokeAILogger.get_logger(), image_files=image_files)
|
||||||
|
return ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_merger(record_store: ModelRecordServiceBase) -> ModelMerger:
|
||||||
|
installer = ModelInstallService(app_config=config, record_store=record_store, download_queue=DownloadQueueService())
|
||||||
|
installer.start()
|
||||||
|
return ModelMerger(installer)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = _parse_args()
|
args = _parse_args()
|
||||||
if args.root_dir:
|
if args.root_dir:
|
||||||
config.parse_args(["--root", str(args.root_dir)])
|
config.parse_args(["--root", str(args.root_dir)])
|
||||||
|
else:
|
||||||
|
config.parse_args([])
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if args.front_end:
|
if args.front_end:
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -65,7 +66,7 @@ def mock_services() -> InvocationServices:
|
|||||||
invocation_cache=MemoryInvocationCache(max_cache_size=0),
|
invocation_cache=MemoryInvocationCache(max_cache_size=0),
|
||||||
latents=None, # type: ignore
|
latents=None, # type: ignore
|
||||||
logger=logging, # type: ignore
|
logger=logging, # type: ignore
|
||||||
model_manager=None, # type: ignore
|
model_manager=Mock(), # type: ignore
|
||||||
download_queue=None, # type: ignore
|
download_queue=None, # type: ignore
|
||||||
names=None, # type: ignore
|
names=None, # type: ignore
|
||||||
performance_statistics=InvocationStatsService(),
|
performance_statistics=InvocationStatsService(),
|
||||||
|
@ -8,6 +8,7 @@ from typing import Any
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
from invokeai.app.services.model_metadata import ModelMetadataStoreSQL
|
||||||
from invokeai.app.services.model_records import (
|
from invokeai.app.services.model_records import (
|
||||||
DuplicateModelException,
|
DuplicateModelException,
|
||||||
ModelRecordOrderBy,
|
ModelRecordOrderBy,
|
||||||
@ -36,7 +37,7 @@ def store(
|
|||||||
config = InvokeAIAppConfig(root=datadir)
|
config = InvokeAIAppConfig(root=datadir)
|
||||||
logger = InvokeAILogger.get_logger(config=config)
|
logger = InvokeAILogger.get_logger(config=config)
|
||||||
db = create_mock_sqlite_database(config, logger)
|
db = create_mock_sqlite_database(config, logger)
|
||||||
return ModelRecordServiceSQL(db)
|
return ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
|
||||||
|
|
||||||
|
|
||||||
def example_config() -> TextualInversionConfig:
|
def example_config() -> TextualInversionConfig:
|
||||||
|
@ -14,6 +14,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
|
|||||||
from invokeai.app.services.download import DownloadQueueService
|
from invokeai.app.services.download import DownloadQueueService
|
||||||
from invokeai.app.services.events.events_base import EventServiceBase
|
from invokeai.app.services.events.events_base import EventServiceBase
|
||||||
from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase
|
from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase
|
||||||
|
from invokeai.app.services.model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL
|
||||||
from invokeai.app.services.model_records import ModelRecordServiceSQL
|
from invokeai.app.services.model_records import ModelRecordServiceSQL
|
||||||
from invokeai.backend.model_manager.config import (
|
from invokeai.backend.model_manager.config import (
|
||||||
BaseModelType,
|
BaseModelType,
|
||||||
@ -21,7 +22,6 @@ from invokeai.backend.model_manager.config import (
|
|||||||
ModelType,
|
ModelType,
|
||||||
)
|
)
|
||||||
from invokeai.backend.model_manager.load import AnyModelLoader, ModelCache, ModelConvertCache
|
from invokeai.backend.model_manager.load import AnyModelLoader, ModelCache, ModelConvertCache
|
||||||
from invokeai.backend.model_manager.metadata import ModelMetadataStore
|
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
from tests.backend.model_manager_2.model_metadata.metadata_examples import (
|
from tests.backend.model_manager_2.model_metadata.metadata_examples import (
|
||||||
RepoCivitaiModelMetadata1,
|
RepoCivitaiModelMetadata1,
|
||||||
@ -104,7 +104,7 @@ def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordS
|
|||||||
def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL:
|
def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL:
|
||||||
logger = InvokeAILogger.get_logger(config=mm2_app_config)
|
logger = InvokeAILogger.get_logger(config=mm2_app_config)
|
||||||
db = create_mock_sqlite_database(mm2_app_config, logger)
|
db = create_mock_sqlite_database(mm2_app_config, logger)
|
||||||
store = ModelRecordServiceSQL(db)
|
store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
|
||||||
# add five simple config records to the database
|
# add five simple config records to the database
|
||||||
raw1 = {
|
raw1 = {
|
||||||
"path": "/tmp/foo1",
|
"path": "/tmp/foo1",
|
||||||
@ -163,15 +163,14 @@ def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStore:
|
def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStoreBase:
|
||||||
db = mm2_record_store._db # to ensure we are sharing the same database
|
return mm2_record_store.metadata_store
|
||||||
return ModelMetadataStore(db)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
|
def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
|
||||||
"""This fixtures defines a series of mock URLs for testing download and installation."""
|
"""This fixtures defines a series of mock URLs for testing download and installation."""
|
||||||
sess = TestSession()
|
sess: Session = TestSession()
|
||||||
sess.mount(
|
sess.mount(
|
||||||
"https://test.com/missing_model.safetensors",
|
"https://test.com/missing_model.safetensors",
|
||||||
TestAdapter(
|
TestAdapter(
|
||||||
@ -258,8 +257,7 @@ def mm2_installer(mm2_app_config: InvokeAIAppConfig, mm2_session: Session) -> Mo
|
|||||||
logger = InvokeAILogger.get_logger()
|
logger = InvokeAILogger.get_logger()
|
||||||
db = create_mock_sqlite_database(mm2_app_config, logger)
|
db = create_mock_sqlite_database(mm2_app_config, logger)
|
||||||
events = DummyEventService()
|
events = DummyEventService()
|
||||||
store = ModelRecordServiceSQL(db)
|
store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
|
||||||
metadata_store = ModelMetadataStore(db)
|
|
||||||
|
|
||||||
download_queue = DownloadQueueService(requests_session=mm2_session)
|
download_queue = DownloadQueueService(requests_session=mm2_session)
|
||||||
download_queue.start()
|
download_queue.start()
|
||||||
@ -268,7 +266,6 @@ def mm2_installer(mm2_app_config: InvokeAIAppConfig, mm2_session: Session) -> Mo
|
|||||||
app_config=mm2_app_config,
|
app_config=mm2_app_config,
|
||||||
record_store=store,
|
record_store=store,
|
||||||
download_queue=download_queue,
|
download_queue=download_queue,
|
||||||
metadata_store=metadata_store,
|
|
||||||
event_bus=events,
|
event_bus=events,
|
||||||
session=mm2_session,
|
session=mm2_session,
|
||||||
)
|
)
|
||||||
|
@ -8,6 +8,7 @@ import pytest
|
|||||||
from pydantic.networks import HttpUrl
|
from pydantic.networks import HttpUrl
|
||||||
from requests.sessions import Session
|
from requests.sessions import Session
|
||||||
|
|
||||||
|
from invokeai.app.services.model_metadata import ModelMetadataStoreBase
|
||||||
from invokeai.backend.model_manager.config import ModelRepoVariant
|
from invokeai.backend.model_manager.config import ModelRepoVariant
|
||||||
from invokeai.backend.model_manager.metadata import (
|
from invokeai.backend.model_manager.metadata import (
|
||||||
CivitaiMetadata,
|
CivitaiMetadata,
|
||||||
@ -15,14 +16,13 @@ from invokeai.backend.model_manager.metadata import (
|
|||||||
CommercialUsage,
|
CommercialUsage,
|
||||||
HuggingFaceMetadata,
|
HuggingFaceMetadata,
|
||||||
HuggingFaceMetadataFetch,
|
HuggingFaceMetadataFetch,
|
||||||
ModelMetadataStore,
|
|
||||||
UnknownMetadataException,
|
UnknownMetadataException,
|
||||||
)
|
)
|
||||||
from invokeai.backend.model_manager.util import select_hf_files
|
from invokeai.backend.model_manager.util import select_hf_files
|
||||||
from tests.backend.model_manager_2.model_manager_2_fixtures import * # noqa F403
|
from tests.backend.model_manager_2.model_manager_2_fixtures import * # noqa F403
|
||||||
|
|
||||||
|
|
||||||
def test_metadata_store_put_get(mm2_metadata_store: ModelMetadataStore) -> None:
|
def test_metadata_store_put_get(mm2_metadata_store: ModelMetadataStoreBase) -> None:
|
||||||
tags = {"text-to-image", "diffusers"}
|
tags = {"text-to-image", "diffusers"}
|
||||||
input_metadata = HuggingFaceMetadata(
|
input_metadata = HuggingFaceMetadata(
|
||||||
name="sdxl-vae",
|
name="sdxl-vae",
|
||||||
@ -40,7 +40,7 @@ def test_metadata_store_put_get(mm2_metadata_store: ModelMetadataStore) -> None:
|
|||||||
assert mm2_metadata_store.list_tags() == tags
|
assert mm2_metadata_store.list_tags() == tags
|
||||||
|
|
||||||
|
|
||||||
def test_metadata_store_update(mm2_metadata_store: ModelMetadataStore) -> None:
|
def test_metadata_store_update(mm2_metadata_store: ModelMetadataStoreBase) -> None:
|
||||||
input_metadata = HuggingFaceMetadata(
|
input_metadata = HuggingFaceMetadata(
|
||||||
name="sdxl-vae",
|
name="sdxl-vae",
|
||||||
author="stabilityai",
|
author="stabilityai",
|
||||||
@ -57,7 +57,7 @@ def test_metadata_store_update(mm2_metadata_store: ModelMetadataStore) -> None:
|
|||||||
assert input_metadata == output_metadata
|
assert input_metadata == output_metadata
|
||||||
|
|
||||||
|
|
||||||
def test_metadata_search(mm2_metadata_store: ModelMetadataStore) -> None:
|
def test_metadata_search(mm2_metadata_store: ModelMetadataStoreBase) -> None:
|
||||||
metadata1 = HuggingFaceMetadata(
|
metadata1 = HuggingFaceMetadata(
|
||||||
name="sdxl-vae",
|
name="sdxl-vae",
|
||||||
author="stabilityai",
|
author="stabilityai",
|
||||||
|
Loading…
Reference in New Issue
Block a user