diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 9b771c5159..c2718b5b2e 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -543,7 +543,7 @@ class ModelInstallService(ModelInstallServiceBase): self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None ) -> str: info = info or ModelProbe.probe(model_path, config) - key = self._create_key() + key = info.key or self._create_key() model_path = model_path.absolute() if model_path.is_relative_to(self.app_config.models_path): diff --git a/invokeai/app/services/model_manager/model_manager_base.py b/invokeai/app/services/model_manager/model_manager_base.py index c25aa6fb47..938e14adcb 100644 --- a/invokeai/app/services/model_manager/model_manager_base.py +++ b/invokeai/app/services/model_manager/model_manager_base.py @@ -1,5 +1,7 @@ # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team +import torch + from abc import ABC, abstractmethod from typing import Optional @@ -32,9 +34,10 @@ class ModelManagerServiceBase(ABC): def build_model_manager( cls, app_config: InvokeAIAppConfig, - db: SqliteDatabase, + model_record_service: ModelRecordServiceBase, download_queue: DownloadQueueServiceBase, events: EventServiceBase, + execution_device: torch.device, ) -> Self: """ Construct the model manager service instance. diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py index d029f9e033..2276111586 100644 --- a/invokeai/app/services/model_manager/model_manager_default.py +++ b/invokeai/app/services/model_manager/model_manager_default.py @@ -1,6 +1,8 @@ # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team """Implementation of ModelManagerServiceBase.""" +import torch + from typing import Optional from typing_extensions import Self @@ -9,6 +11,7 @@ from invokeai.app.services.invoker import Invoker from invokeai.app.services.shared.invocation_context import InvocationContextData from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, LoadedModel, ModelType, SubModelType from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry +from invokeai.backend.util.devices import choose_torch_device from invokeai.backend.util.logging import InvokeAILogger from ..config import InvokeAIAppConfig @@ -119,6 +122,7 @@ class ModelManagerService(ModelManagerServiceBase): model_record_service: ModelRecordServiceBase, download_queue: DownloadQueueServiceBase, events: EventServiceBase, + execution_device: torch.device = choose_torch_device(), ) -> Self: """ Construct the model manager service instance. @@ -129,7 +133,10 @@ class ModelManagerService(ModelManagerServiceBase): logger.setLevel(app_config.log_level.upper()) ram_cache = ModelCache( - max_cache_size=app_config.ram_cache_size, max_vram_cache_size=app_config.vram_cache_size, logger=logger + max_cache_size=app_config.ram_cache_size, + max_vram_cache_size=app_config.vram_cache_size, + logger=logger, + execution_device=execution_device, ) convert_cache = ModelConvertCache( cache_path=app_config.models_convert_cache_path, max_size=app_config.convert_cache_size