mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
device selection calls go through TorchDevice
This commit is contained in:
parent
371f5bc782
commit
99558de178
@ -1,7 +1,6 @@
|
|||||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
||||||
"""Implementation of ModelManagerServiceBase."""
|
"""Implementation of ModelManagerServiceBase."""
|
||||||
|
|
||||||
import torch
|
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
from invokeai.app.services.invoker import Invoker
|
from invokeai.app.services.invoker import Invoker
|
||||||
@ -75,13 +74,6 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
logger = InvokeAILogger.get_logger(cls.__name__)
|
logger = InvokeAILogger.get_logger(cls.__name__)
|
||||||
logger.setLevel(app_config.log_level.upper())
|
logger.setLevel(app_config.log_level.upper())
|
||||||
|
|
||||||
execution_devices = (
|
|
||||||
None
|
|
||||||
if app_config.devices is None
|
|
||||||
else None
|
|
||||||
if "auto" in app_config.devices
|
|
||||||
else {torch.device(x) for x in app_config.devices}
|
|
||||||
)
|
|
||||||
ram_cache = ModelCache(
|
ram_cache = ModelCache(
|
||||||
max_cache_size=app_config.ram,
|
max_cache_size=app_config.ram,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
|
@ -86,9 +86,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
|||||||
|
|
||||||
# device to thread id
|
# device to thread id
|
||||||
self._device_lock = threading.Lock()
|
self._device_lock = threading.Lock()
|
||||||
self._execution_devices: Dict[torch.device, int] = {
|
self._execution_devices: Dict[torch.device, int] = {x: 0 for x in TorchDevice.execution_devices()}
|
||||||
x: 0 for x in TorchDevice.execution_devices()
|
|
||||||
}
|
|
||||||
self._free_execution_device = BoundedSemaphore(len(self._execution_devices))
|
self._free_execution_device = BoundedSemaphore(len(self._execution_devices))
|
||||||
|
|
||||||
self.logger.info(
|
self.logger.info(
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
"""Torch Device class provides torch device selection services."""
|
"""Torch Device class provides torch device selection services."""
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union, Set
|
from typing import TYPE_CHECKING, Dict, Literal, Optional, Set, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from deprecated import deprecated
|
from deprecated import deprecated
|
||||||
@ -141,4 +141,3 @@ class TorchDevice:
|
|||||||
else:
|
else:
|
||||||
devices = {torch.device("cpu")}
|
devices = {torch.device("cpu")}
|
||||||
return devices
|
return devices
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user