Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
device selection calls go through TorchDevice

commit 99558de178 (parent 371f5bc782)
@@ -1,7 +1,6 @@
 # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
 """Implementation of ModelManagerServiceBase."""

-import torch
 from typing_extensions import Self

 from invokeai.app.services.invoker import Invoker
@@ -75,13 +74,6 @@ class ModelManagerService(ModelManagerServiceBase):
         logger = InvokeAILogger.get_logger(cls.__name__)
         logger.setLevel(app_config.log_level.upper())

-        execution_devices = (
-            None
-            if app_config.devices is None
-            else None
-            if "auto" in app_config.devices
-            else {torch.device(x) for x in app_config.devices}
-        )
         ram_cache = ModelCache(
             max_cache_size=app_config.ram,
             logger=logger,
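For reference, the removed chained conditional resolved the configured devices before handing them to the cache; that responsibility now sits behind TorchDevice. A minimal standalone sketch of the equivalent selection logic (the parse_devices name and the free-function form are illustrative, not part of the codebase):

```python
from typing import Optional, Set

import torch


def parse_devices(devices: Optional[list[str]]) -> Optional[Set[torch.device]]:
    """Return an explicit device set, or None to defer to automatic selection."""
    if devices is None or "auto" in devices:
        return None  # no explicit list, or "auto" requested: let automatic selection decide
    return {torch.device(d) for d in devices}
```

Written this way, parse_devices(None), parse_devices(["auto"]), and parse_devices(["cuda:0", "cuda:1"]) make the three branches of the original expression explicit.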
@@ -86,9 +86,7 @@ class ModelCache(ModelCacheBase[AnyModel]):

         # device to thread id
         self._device_lock = threading.Lock()
-        self._execution_devices: Dict[torch.device, int] = {
-            x: 0 for x in TorchDevice.execution_devices()
-        }
+        self._execution_devices: Dict[torch.device, int] = {x: 0 for x in TorchDevice.execution_devices()}
         self._free_execution_device = BoundedSemaphore(len(self._execution_devices))

         self.logger.info(
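The bookkeeping above pairs a device-to-thread-id dict with a BoundedSemaphore sized to the number of execution devices, so at most one worker thread holds a given device at a time. A minimal sketch of that pattern in isolation (DevicePool and its method names are illustrative, not the InvokeAI API):

```python
import threading

import torch


class DevicePool:
    """Hand out execution devices to worker threads, one holder per device at a time."""

    def __init__(self, devices: set[torch.device]) -> None:
        self._lock = threading.Lock()
        # device -> thread id of the current holder (0 means free)
        self._claims: dict[torch.device, int] = {d: 0 for d in devices}
        self._free = threading.BoundedSemaphore(len(devices))

    def acquire(self) -> torch.device:
        self._free.acquire()  # block until at least one device is unclaimed
        with self._lock:
            for device, owner in self._claims.items():
                if owner == 0:
                    self._claims[device] = threading.get_ident()
                    return device
        raise RuntimeError("semaphore and claim table out of sync")

    def release(self, device: torch.device) -> None:
        with self._lock:
            self._claims[device] = 0
        self._free.release()
```

A BoundedSemaphore makes a double release fail loudly instead of silently inflating the pool, which is presumably why it is used here rather than a plain Semaphore.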
@@ -1,6 +1,6 @@
 """Torch Device class provides torch device selection services."""

-from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union, Set
+from typing import TYPE_CHECKING, Dict, Literal, Optional, Set, Union

 import torch
 from deprecated import deprecated
@@ -141,4 +141,3 @@ class TorchDevice:
         else:
             devices = {torch.device("cpu")}
         return devices
-
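The hunk above is the CPU fallback at the end of TorchDevice.execution_devices(). For orientation, a hedged sketch of what such an enumerator typically looks like, reconstructed from the visible fallback rather than copied from the InvokeAI source (the free-function form and the CUDA/MPS probing order are assumptions):

```python
from typing import Set

import torch


def enumerate_execution_devices() -> Set[torch.device]:
    """Illustrative enumeration: every CUDA device, else MPS, else the CPU fallback."""
    if torch.cuda.is_available():
        devices = {torch.device(f"cuda:{i}") for i in range(torch.cuda.device_count())}
    elif torch.backends.mps.is_available():
        devices = {torch.device("mps")}
    else:
        devices = {torch.device("cpu")}
    return devices
```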