mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fix issues identified during PR review by RyanjDick and brandonrising
- ModelMetadataStoreService is now injected into ModelRecordStoreService (these two services are really joined at the hip, and should someday be merged) - ModelRecordStoreService is now injected into ModelManagerService - Reduced timeout value for the various installer and download wait*() methods - Introduced a Mock modelmanager for testing - Removed bare print() statement with _logger in the install helper backend. - Removed unused code from model loader init file - Made `locker` a private variable in the `LoadedModel` object. - Fixed up model merge frontend (will be deprecated anyway!)
This commit is contained in:
committed by
psychedelicious
parent
f1597bd6da
commit
996eb96b4e
@ -25,6 +25,7 @@ from invokeai.app.services.model_install import (
|
||||
ModelSource,
|
||||
URLModelSource,
|
||||
)
|
||||
from invokeai.app.services.model_metadata import ModelMetadataStoreSQL
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
|
||||
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
||||
from invokeai.backend.model_manager import (
|
||||
@ -45,7 +46,7 @@ def initialize_record_store(app_config: InvokeAIAppConfig) -> ModelRecordService
|
||||
logger = InvokeAILogger.get_logger(config=app_config)
|
||||
image_files = DiskImageFileStorage(f"{app_config.output_path}/images")
|
||||
db = init_db(config=app_config, logger=logger, image_files=image_files)
|
||||
obj: ModelRecordServiceBase = ModelRecordServiceSQL(db)
|
||||
obj: ModelRecordServiceBase = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
|
||||
return obj
|
||||
|
||||
|
||||
@ -54,12 +55,10 @@ def initialize_installer(
|
||||
) -> ModelInstallServiceBase:
|
||||
"""Return an initialized ModelInstallService object."""
|
||||
record_store = initialize_record_store(app_config)
|
||||
metadata_store = record_store.metadata_store
|
||||
download_queue = DownloadQueueService()
|
||||
installer = ModelInstallService(
|
||||
app_config=app_config,
|
||||
record_store=record_store,
|
||||
metadata_store=metadata_store,
|
||||
download_queue=download_queue,
|
||||
event_bus=event_bus,
|
||||
)
|
||||
@ -287,14 +286,14 @@ class InstallHelper(object):
|
||||
model_name=model_name,
|
||||
)
|
||||
if len(matches) > 1:
|
||||
print(
|
||||
f"{model_to_remove} is ambiguous. Please use model_base/model_type/model_name (e.g. sd-1/main/my_model) to disambiguate."
|
||||
self._logger.error(
|
||||
"{model_to_remove} is ambiguous. Please use model_base/model_type/model_name (e.g. sd-1/main/my_model) to disambiguate"
|
||||
)
|
||||
elif not matches:
|
||||
print(f"{model_to_remove}: unknown model")
|
||||
self._logger.error(f"{model_to_remove}: unknown model")
|
||||
else:
|
||||
for m in matches:
|
||||
print(f"Deleting {m.type}:{m.name}")
|
||||
self._logger.info(f"Deleting {m.type}:{m.name}")
|
||||
installer.delete(m.key)
|
||||
|
||||
installer.wait_for_installs()
|
||||
|
@ -4,10 +4,6 @@ Init file for the model loader.
|
||||
"""
|
||||
from importlib import import_module
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
from .convert_cache.convert_cache_default import ModelConvertCache
|
||||
from .load_base import AnyModelLoader, LoadedModel
|
||||
@ -19,16 +15,3 @@ for module in loaders:
|
||||
import_module(f"{__package__}.model_loaders.{module}")
|
||||
|
||||
__all__ = ["AnyModelLoader", "LoadedModel", "ModelCache", "ModelConvertCache"]
|
||||
|
||||
|
||||
def get_standalone_loader(app_config: Optional[InvokeAIAppConfig]) -> AnyModelLoader:
|
||||
app_config = app_config or InvokeAIAppConfig.get_config()
|
||||
logger = InvokeAILogger.get_logger(config=app_config)
|
||||
return AnyModelLoader(
|
||||
app_config=app_config,
|
||||
logger=logger,
|
||||
ram_cache=ModelCache(
|
||||
logger=logger, max_cache_size=app_config.ram_cache_size, max_vram_cache_size=app_config.vram_cache_size
|
||||
),
|
||||
convert_cache=ModelConvertCache(app_config.models_convert_cache_path),
|
||||
)
|
||||
|
@ -39,21 +39,21 @@ class LoadedModel:
|
||||
"""Context manager object that mediates transfer from RAM<->VRAM."""
|
||||
|
||||
config: AnyModelConfig
|
||||
locker: ModelLockerBase
|
||||
_locker: ModelLockerBase
|
||||
|
||||
def __enter__(self) -> AnyModel:
|
||||
"""Context entry."""
|
||||
self.locker.lock()
|
||||
self._locker.lock()
|
||||
return self.model
|
||||
|
||||
def __exit__(self, *args: Any, **kwargs: Any) -> None:
|
||||
"""Context exit."""
|
||||
self.locker.unlock()
|
||||
self._locker.unlock()
|
||||
|
||||
@property
|
||||
def model(self) -> AnyModel:
|
||||
"""Return the model without locking it."""
|
||||
return self.locker.model
|
||||
return self._locker.model
|
||||
|
||||
|
||||
class ModelLoaderBase(ABC):
|
||||
|
@ -75,7 +75,7 @@ class ModelLoader(ModelLoaderBase):
|
||||
|
||||
model_path = self._convert_if_needed(model_config, model_path, submodel_type)
|
||||
locker = self._load_if_needed(model_config, model_path, submodel_type)
|
||||
return LoadedModel(config=model_config, locker=locker)
|
||||
return LoadedModel(config=model_config, _locker=locker)
|
||||
|
||||
def _get_model_path(
|
||||
self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None
|
||||
|
@ -39,10 +39,7 @@ class ModelMerger(object):
|
||||
|
||||
def __init__(self, installer: ModelInstallServiceBase):
|
||||
"""
|
||||
Initialize a ModelMerger object.
|
||||
|
||||
:param store: Underlying storage manager for the running process.
|
||||
:param config: InvokeAIAppConfig object (if not provided, default will be selected).
|
||||
Initialize a ModelMerger object with the model installer.
|
||||
"""
|
||||
self._installer = installer
|
||||
|
||||
|
@ -18,7 +18,7 @@ assert isinstance(data, CivitaiMetadata)
|
||||
if data.allow_commercial_use:
|
||||
print("Commercial use of this model is allowed")
|
||||
"""
|
||||
from .fetch import CivitaiMetadataFetch, HuggingFaceMetadataFetch
|
||||
from .fetch import CivitaiMetadataFetch, HuggingFaceMetadataFetch, ModelMetadataFetchBase
|
||||
from .metadata_base import (
|
||||
AnyModelRepoMetadata,
|
||||
AnyModelRepoMetadataValidator,
|
||||
@ -31,7 +31,6 @@ from .metadata_base import (
|
||||
RemoteModelFile,
|
||||
UnknownMetadataException,
|
||||
)
|
||||
from .metadata_store import ModelMetadataStore
|
||||
|
||||
__all__ = [
|
||||
"AnyModelRepoMetadata",
|
||||
@ -42,7 +41,7 @@ __all__ = [
|
||||
"HuggingFaceMetadata",
|
||||
"HuggingFaceMetadataFetch",
|
||||
"LicenseRestrictions",
|
||||
"ModelMetadataStore",
|
||||
"ModelMetadataFetchBase",
|
||||
"BaseMetadata",
|
||||
"ModelMetadataWithFiles",
|
||||
"RemoteModelFile",
|
||||
|
Reference in New Issue
Block a user