From 99558de17863196cd9e8d874d046a4f917a3c22f Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Tue, 16 Apr 2024 16:26:58 -0400
Subject: [PATCH] device selection calls go through TorchDevice

---
 .../app/services/model_manager/model_manager_default.py   | 8 --------
 .../model_manager/load/model_cache/model_cache_default.py | 4 +---
 invokeai/backend/util/devices.py                          | 3 +--
 3 files changed, 2 insertions(+), 13 deletions(-)

diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py
index 902501c1f9..ccb68f783b 100644
--- a/invokeai/app/services/model_manager/model_manager_default.py
+++ b/invokeai/app/services/model_manager/model_manager_default.py
@@ -1,7 +1,6 @@
 # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
 """Implementation of ModelManagerServiceBase."""
 
-import torch
 from typing_extensions import Self
 
 from invokeai.app.services.invoker import Invoker
@@ -75,13 +74,6 @@ class ModelManagerService(ModelManagerServiceBase):
         logger = InvokeAILogger.get_logger(cls.__name__)
         logger.setLevel(app_config.log_level.upper())
 
-        execution_devices = (
-            None
-            if app_config.devices is None
-            else None
-            if "auto" in app_config.devices
-            else {torch.device(x) for x in app_config.devices}
-        )
         ram_cache = ModelCache(
             max_cache_size=app_config.ram,
             logger=logger,
diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
index 9eab509e64..551412d66a 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
@@ -86,9 +86,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
 
         # device to thread id
         self._device_lock = threading.Lock()
-        self._execution_devices: Dict[torch.device, int] = {
-            x: 0 for x in TorchDevice.execution_devices()
-        }
+        self._execution_devices: Dict[torch.device, int] = {x: 0 for x in TorchDevice.execution_devices()}
         self._free_execution_device = BoundedSemaphore(len(self._execution_devices))
 
         self.logger.info(
diff --git a/invokeai/backend/util/devices.py b/invokeai/backend/util/devices.py
index d1c432a53f..745c128099 100644
--- a/invokeai/backend/util/devices.py
+++ b/invokeai/backend/util/devices.py
@@ -1,6 +1,6 @@
 """Torch Device class provides torch device selection services."""
 
-from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union, Set
+from typing import TYPE_CHECKING, Dict, Literal, Optional, Set, Union
 
 import torch
 from deprecated import deprecated
@@ -141,4 +141,3 @@
         else:
             devices = {torch.device("cpu")}
         return devices
-