device selection calls go through TorchDevice

This commit is contained in:
Lincoln Stein 2024-04-16 16:26:58 -04:00
parent 371f5bc782
commit 99558de178
3 changed files with 2 additions and 13 deletions

View File

@ -1,7 +1,6 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
"""Implementation of ModelManagerServiceBase.""" """Implementation of ModelManagerServiceBase."""
import torch
from typing_extensions import Self from typing_extensions import Self
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
@ -75,13 +74,6 @@ class ModelManagerService(ModelManagerServiceBase):
logger = InvokeAILogger.get_logger(cls.__name__) logger = InvokeAILogger.get_logger(cls.__name__)
logger.setLevel(app_config.log_level.upper()) logger.setLevel(app_config.log_level.upper())
execution_devices = (
None
if app_config.devices is None
else None
if "auto" in app_config.devices
else {torch.device(x) for x in app_config.devices}
)
ram_cache = ModelCache( ram_cache = ModelCache(
max_cache_size=app_config.ram, max_cache_size=app_config.ram,
logger=logger, logger=logger,

View File

@ -86,9 +86,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
# device to thread id # device to thread id
self._device_lock = threading.Lock() self._device_lock = threading.Lock()
self._execution_devices: Dict[torch.device, int] = { self._execution_devices: Dict[torch.device, int] = {x: 0 for x in TorchDevice.execution_devices()}
x: 0 for x in TorchDevice.execution_devices()
}
self._free_execution_device = BoundedSemaphore(len(self._execution_devices)) self._free_execution_device = BoundedSemaphore(len(self._execution_devices))
self.logger.info( self.logger.info(

View File

@ -1,6 +1,6 @@
"""Torch Device class provides torch device selection services.""" """Torch Device class provides torch device selection services."""
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union, Set from typing import TYPE_CHECKING, Dict, Literal, Optional, Set, Union
import torch import torch
from deprecated import deprecated from deprecated import deprecated
@ -141,4 +141,3 @@ class TorchDevice:
else: else:
devices = {torch.device("cpu")} devices = {torch.device("cpu")}
return devices return devices