Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

commit 3d69372785 (parent eca29c41d0): implement session-level reservation of gpus
@@ -100,7 +100,8 @@ class InvokeAIAppConfig(BaseSettings):
         ram: Maximum memory amount used by memory model cache for rapid switching (GB).
         convert_cache: Maximum size of on-disk converted models cache (GB).
         log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
-        device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
+        device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda:0`, `cuda:1`, `cuda:2`, `cuda:3`, `cuda:4`, `cuda:5`, `mps`
+        devices: List of execution devices; will override default device selected.
         precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`, `autocast`
         sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
         attention_type: Attention type.<br>Valid values: `auto`, `normal`, `xformers`, `sliced`, `torch-sdp`
@@ -108,6 +109,7 @@ class InvokeAIAppConfig(BaseSettings):
         force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).
         pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.
         max_queue_size: Maximum number of items in the session queue.
+        max_threads: Maximum number of session queue execution threads.
         allow_nodes: List of nodes to allow. Omit to allow all.
         deny_nodes: List of nodes to deny. Omit to deny none.
         node_cache_size: How many cached nodes to keep in memory.
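The `device`, `devices`, and `max_threads` settings above are the user-facing side of multi-GPU support. Below is a minimal sketch of how they might be set programmatically (the import path and the exact field types are assumptions, not taken from this commit):

from invokeai.app.services.config import InvokeAIAppConfig  # import path assumed

config = InvokeAIAppConfig(
    device="auto",                 # let the model manager pick a device per session
    devices=["cuda:0", "cuda:1"],  # restrict execution to these two GPUs
    max_threads=2,                 # run up to two session-queue worker threads
)

The same keys could equally be supplied through the invokeai.yaml settings file.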
@@ -181,47 +181,51 @@ class DefaultSessionProcessor(SessionProcessorBase):
             if profiler is not None:
                 profiler.start(profile_id=session.session_id)
 
-            # Prepare invocations and take the first
-            with self._process_lock:
-                invocation = session.session.next()
-
-            # Loop over invocations until the session is complete or canceled
-            while invocation is not None:
-                if self._stop_event.is_set():
-                    break
-                self._resume_event.wait()
-
-                self._process_next_invocation(session, invocation, stats_service)
-
-                # The session is complete if all invocations are complete or there was an error
-                if session.session.is_complete():
-                    # Send complete event
-                    self._invoker.services.events.emit_graph_execution_complete(
-                        queue_batch_id=session.batch_id,
-                        queue_item_id=session.item_id,
-                        queue_id=session.queue_id,
-                        graph_execution_state_id=session.session.id,
-                    )
-                    # Log stats
-                    # We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
-                    # we don't care about that - suppress the error.
-                    with suppress(GESStatsNotFoundError):
-                        stats_service.log_stats(session.session.id)
-                        stats_service.reset_stats()
-
-                    # If we are profiling, stop the profiler and dump the profile & stats
-                    if self._profiler:
-                        profile_path = self._profiler.stop()
-                        stats_path = profile_path.with_suffix(".json")
-                        stats_service.dump_stats(
-                            graph_execution_state_id=session.session.id, output_path=stats_path
-                        )
-                    self._queue_items.remove(session.item_id)
-                    invocation = None
-                else:
-                    # Prepare the next invocation
-                    with self._process_lock:
-                        invocation = session.session.next()
+            # reserve a GPU for this session - may block
+            with self._invoker.services.model_manager.load.ram_cache.reserve_execution_device() as gpu:
+                print(f"DEBUG: session {session.item_id} has reserved gpu {gpu}")
+
+                # Prepare invocations and take the first
+                with self._process_lock:
+                    invocation = session.session.next()
+
+                # Loop over invocations until the session is complete or canceled
+                while invocation is not None:
+                    if self._stop_event.is_set():
+                        break
+                    self._resume_event.wait()
+
+                    self._process_next_invocation(session, invocation, stats_service)
+
+                    # The session is complete if all invocations are complete or there was an error
+                    if session.session.is_complete():
+                        # Send complete event
+                        self._invoker.services.events.emit_graph_execution_complete(
+                            queue_batch_id=session.batch_id,
+                            queue_item_id=session.item_id,
+                            queue_id=session.queue_id,
+                            graph_execution_state_id=session.session.id,
+                        )
+                        # Log stats
+                        # We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
+                        # we don't care about that - suppress the error.
+                        with suppress(GESStatsNotFoundError):
+                            stats_service.log_stats(session.session.id)
+                            stats_service.reset_stats()
+
+                        # If we are profiling, stop the profiler and dump the profile & stats
+                        if self._profiler:
+                            profile_path = self._profiler.stop()
+                            stats_path = profile_path.with_suffix(".json")
+                            stats_service.dump_stats(
+                                graph_execution_state_id=session.session.id, output_path=stats_path
+                            )
+                        self._queue_items.remove(session.item_id)
+                        invocation = None
+                    else:
+                        # Prepare the next invocation
+                        with self._process_lock:
+                            invocation = session.session.next()
 
         except Exception:
             # Non-fatal error in processor
@@ -8,9 +8,10 @@ model will be cleared and (re)loaded from disk when next needed.
 """
 
 from abc import ABC, abstractmethod
+from contextlib import contextmanager
 from dataclasses import dataclass, field
 from logging import Logger
-from typing import Dict, Generic, Optional, Set, TypeVar
+from typing import Dict, Generator, Generic, Optional, Set, TypeVar
 
 import torch
 
@@ -93,20 +94,23 @@ class ModelCacheBase(ABC, Generic[T]):
         """Return the set of available execution devices."""
         pass
 
+    @contextmanager
     @abstractmethod
-    def acquire_execution_device(self, timeout: int = 0) -> torch.device:
-        """
-        Pick the next available execution device.
-
-        If all devices are currently engaged (locked), then
-        block until timeout seconds have passed and raise a
-        TimeoutError if no devices are available.
-        """
+    def reserve_execution_device(self, timeout: int = 0) -> Generator[torch.device, None, None]:
+        """Reserve an execution device (GPU) under the current thread id."""
         pass
 
     @abstractmethod
-    def release_execution_device(self, device: torch.device) -> None:
-        """Release a previously-acquired execution device."""
+    def get_execution_device(self) -> torch.device:
+        """
+        Return an execution device that has been reserved for current thread.
+
+        Note that reservations are done using the current thread's TID.
+        It would be better to do this using the session ID, but that involves
+        too many detailed changes to model manager calls.
+
+        May generate a ValueError if no GPU has been reserved.
+        """
         pass
 
     @property
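Taken together, the two abstract methods split the work between a reservation made once per session thread and lookups made later on that same thread. Below is an illustrative sketch of the intended call pattern, not part of the commit; `ram_cache` stands for any ModelCacheBase implementation and `run_invocation` is a hypothetical stand-in for the invocation code:

def run_session(ram_cache, invocations) -> None:
    # Blocks until a GPU is free, then holds it for the whole session.
    with ram_cache.reserve_execution_device() as gpu:
        for invocation in invocations:
            run_invocation(invocation)  # hypothetical; does not need to be handed `gpu`
            # Model-loading code running on this thread can recover the reserved
            # device, because the reservation is keyed on the current thread id:
            assert ram_cache.get_execution_device() == gpu
    # Leaving the with-block releases the device to the next waiting session.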
@@ -21,13 +21,12 @@ context. Use like this:
 import gc
 import sys
 import threading
-from contextlib import suppress
+from contextlib import contextmanager, suppress
 from logging import Logger
-from threading import BoundedSemaphore, Lock
-from typing import Dict, List, Optional, Set
+from threading import BoundedSemaphore
+from typing import Dict, Generator, List, Optional, Set
 
 import torch
-from pydantic import BaseModel
 
 from invokeai.backend.model_manager import AnyModel, SubModelType
 from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot
@@ -51,26 +50,6 @@ GIG = 1073741824
 MB = 2**20
 
 
-# GPU device can only be used by one thread at a time.
-# The refcount indicates the number of models stored
-# in it.
-class GPUDeviceStatus(BaseModel):
-    """Track of which threads are using the GPU(s) on this system."""
-
-    device: torch.device
-    thread_id: int = 0
-    refcount: int = 0
-
-    class Config:
-        """Configure the base model."""
-
-        arbitrary_types_allowed = True
-
-    def __hash__(self) -> int:
-        """Allow to be added to a set."""
-        return hash(str(torch.device))
-
-
 class ModelCache(ModelCacheBase[AnyModel]):
     """Implementation of ModelCacheBase."""
 
@@ -100,9 +79,8 @@ class ModelCache(ModelCacheBase[AnyModel]):
         """
         self._precision: torch.dtype = precision
         self._max_cache_size: float = max_cache_size
-        self._execution_devices: Set[GPUDeviceStatus] = self._get_execution_devices(execution_devices)
         self._storage_device: torch.device = storage_device
-        self._lock = threading.Lock()
+        self._ram_lock = threading.Lock()
         self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
         self._log_memory_usage = log_memory_usage
         self._stats: Optional[CacheStats] = None
@@ -110,11 +88,15 @@ class ModelCache(ModelCacheBase[AnyModel]):
         self._cached_models: Dict[str, CacheRecord[AnyModel]] = {}
         self._cache_stack: List[str] = []
 
-        self._lock = Lock()
+        # device to thread id
+        self._device_lock = threading.Lock()
+        self._execution_devices: Dict[torch.device, int] = {
+            x: 0 for x in execution_devices or self._get_execution_devices()
+        }
         self._free_execution_device = BoundedSemaphore(len(self._execution_devices))
 
         self.logger.info(
-            f"Using rendering device(s): {', '.join(sorted([str(x.device) for x in self._execution_devices]))}"
+            f"Using rendering device(s): {', '.join(sorted([str(x) for x in self._execution_devices.keys()]))}"
         )
 
     @property
@@ -130,34 +112,61 @@ class ModelCache(ModelCacheBase[AnyModel]):
     @property
     def execution_devices(self) -> Set[torch.device]:
         """Return the set of available execution devices."""
-        return {x.device for x in self._execution_devices}
+        devices = self._execution_devices.keys()
+        return set(devices)
 
-    def acquire_execution_device(self, timeout: int = 0) -> torch.device:
-        """Acquire and return an execution device (e.g. "cuda" for VRAM)."""
-        current_thread = threading.current_thread().ident
-        assert current_thread is not None
-
-        # first try to assign a device that is already executing on this thread
-        if claimed_devices := [x for x in self._execution_devices if x.thread_id == current_thread]:
-            claimed_devices[0].refcount += 1
-            return claimed_devices[0].device
-        else:
-            # this thread is not currently using any gpu. Wait for a free one
-            self._free_execution_device.acquire(timeout=timeout)
-            unclaimed_devices = [x for x in self._execution_devices if x.refcount == 0]
-            unclaimed_devices[0].thread_id = current_thread
-            unclaimed_devices[0].refcount += 1
-            return unclaimed_devices[0].device
-
-    def release_execution_device(self, device: torch.device) -> None:
-        """Mark this execution device as unused."""
-        current_thread = threading.current_thread().ident
-        for x in self._execution_devices:
-            if x.thread_id == current_thread and x.device == device:
-                x.refcount -= 1
-                if x.refcount == 0:
-                    x.thread_id = 0
-                    self._free_execution_device.release()
+    def get_execution_device(self) -> torch.device:
+        """
+        Return an execution device that has been reserved for current thread.
+
+        Note that reservations are done using the current thread's TID.
+        It would be better to do this using the session ID, but that involves
+        too many detailed changes to model manager calls.
+
+        May generate a ValueError if no GPU has been reserved.
+        """
+        current_thread = threading.current_thread().ident
+        assert current_thread is not None
+        assigned = [x for x, tid in self._execution_devices.items() if current_thread == tid]
+        if not assigned:
+            raise ValueError("No GPU has been reserved for the use of thread {current_thread}")
+        return assigned[0]
+
+    @contextmanager
+    def reserve_execution_device(self, timeout: Optional[int] = None) -> Generator[torch.device, None, None]:
+        """Reserve an execution device (e.g. GPU) for exclusive use by a generation thread.
+
+        Note that the reservation is done using the current thread's TID.
+        It would be better to do this using the session ID, but that involves
+        too many detailed changes to model manager calls.
+        """
+        device = None
+        with self._device_lock:
+            current_thread = threading.current_thread().ident
+            assert current_thread is not None
+
+            # look for a device that has already been assigned to this thread
+            assigned = [x for x, tid in self._execution_devices.items() if current_thread == tid]
+            if assigned:
+                device = assigned[0]
+
+        # no device already assigned. Get one.
+        if device is None:
+            self._free_execution_device.acquire(timeout=timeout)
+            with self._device_lock:
+                free_device = [x for x, tid in self._execution_devices.items() if tid == 0]
+                print(f"DEBUG: execution devices = {self._execution_devices}")
+                self._execution_devices[free_device[0]] = current_thread
+                device = free_device[0]
+
+        # we are outside the lock region now
+        try:
+            yield device
+        finally:
+            with self._device_lock:
+                self._execution_devices[device] = 0
+                self._free_execution_device.release()
+                torch.cuda.empty_cache()
 
     @property
     def max_cache_size(self) -> float:
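The reservation machinery above reduces to three pieces: a BoundedSemaphore counting free devices, a dict mapping each device to the thread id that currently holds it (0 meaning free), and a lock guarding that dict. The following is a self-contained sketch of the same scheme that runs without torch or InvokeAI; plain strings stand in for torch.device, and the branch that lets a thread re-enter a reservation it already holds is omitted for brevity:

import threading
from contextlib import contextmanager
from typing import Dict, Generator, List, Optional


class DeviceReserver:
    """Toy model of the device-reservation scheme sketched by the ModelCache changes above."""

    def __init__(self, devices: List[str]) -> None:
        self._devices: Dict[str, int] = {d: 0 for d in devices}  # device -> holder thread id (0 = free)
        self._device_lock = threading.Lock()
        self._free = threading.BoundedSemaphore(len(devices))

    @contextmanager
    def reserve(self, timeout: Optional[float] = None) -> Generator[str, None, None]:
        self._free.acquire(timeout=timeout)  # block until some device is free
        with self._device_lock:
            device = next(d for d, tid in self._devices.items() if tid == 0)
            self._devices[device] = threading.get_ident()
        try:
            yield device
        finally:
            with self._device_lock:
                self._devices[device] = 0  # hand the device back
            self._free.release()  # wake up the next waiting session thread

    def get(self) -> str:
        """Look up the device reserved by the current thread (cf. get_execution_device)."""
        tid = threading.get_ident()
        with self._device_lock:
            assigned = [d for d, holder in self._devices.items() if holder == tid]
        if not assigned:
            raise ValueError(f"No device has been reserved by thread {tid}")
        return assigned[0]


def fake_session(name: str, reserver: DeviceReserver) -> None:
    with reserver.reserve() as device:  # blocks when every device is taken
        assert reserver.get() == device  # any code on this thread can find it again
        print(f"{name} rendering on {device}")


if __name__ == "__main__":
    reserver = DeviceReserver(["cuda:0", "cuda:1"])
    threads = [threading.Thread(target=fake_session, args=(f"session-{i}", reserver)) for i in range(4)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

With four worker threads and two fake devices, two sessions proceed immediately while the other two block inside reserve() until a device is handed back.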
@@ -203,7 +212,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
         submodel_type: Optional[SubModelType] = None,
     ) -> None:
         """Store model under key and optional submodel_type."""
-        with self._lock:
+        with self._ram_lock:
             key = self._make_cache_key(key, submodel_type)
             if key in self._cached_models:
                 return
@@ -228,7 +237,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
 
         This may raise an IndexError if the model is not in the cache.
         """
-        with self._lock:
+        with self._ram_lock:
             key = self._make_cache_key(key, submodel_type)
             if key in self._cached_models:
                 if self.stats:
@@ -395,7 +404,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
             raise torch.cuda.OutOfMemoryError
 
     @staticmethod
-    def _get_execution_devices(devices: Optional[Set[torch.device]] = None) -> Set[GPUDeviceStatus]:
+    def _get_execution_devices(devices: Optional[Set[torch.device]] = None) -> Set[torch.device]:
        if not devices:
            default_device = choose_torch_device()
            if default_device != torch.device("cuda"):
@@ -403,7 +412,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
            else:
                # we get here if the default device is cuda, and return each of the cuda devices.
                devices = {torch.device(f"cuda:{x}") for x in range(0, torch.cuda.device_count())}
-        return {GPUDeviceStatus(device=x) for x in devices}
+        return devices
 
     @staticmethod
     def _device_name(device: torch.device) -> str:
@@ -56,8 +56,8 @@ class ModelLocker(ModelLockerBase):
         self._cache_entry.lock()
 
         try:
-            # We wait for a gpu to be free - may raise a TimeoutError
-            self._execution_device = self._cache.acquire_execution_device(MAX_GPU_WAIT)
+            # We wait for a gpu to be free - may raise a ValueError
+            self._execution_device = self._cache.get_execution_device()
             self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._execution_device}")
             model_in_gpu = copy.deepcopy(self._cache_entry.model)
             if hasattr(model_in_gpu, "to"):
@@ -77,14 +77,5 @@ class ModelLocker(ModelLockerBase):
         """Call upon exit from context."""
         if not hasattr(self.model, "to"):
             return
 
         self._cache_entry.unlock()
-        if self._execution_device:
-            self._cache.release_execution_device(self._execution_device)
-
-        try:
-            torch.cuda.empty_cache()
-            torch.mps.empty_cache()
-        except Exception:
-            pass
         self._cache.print_cuda_stats()
@@ -15,7 +15,15 @@ MPS_DEVICE = torch.device("mps")
 
 
 def choose_torch_device() -> torch.device:
-    """Convenience routine for guessing which GPU device to run model on"""
+    """Convenience routine for guessing which GPU device to run model on."""
+    # """Temporarily modified to use the model manager's get_execution_device()"""
+    # try:
+    #     from invokeai.app.api.dependencies import ApiDependencies
+    #     model_manager = ApiDependencies.invoker.services.model_manager
+    #     device = model_manager.load.ram_cache.acquire_execution_device()
+    #     print(f'DEBUG choose_torch_device returning {device}')
+    #     return device
+    # except Exception:
     config = get_config()
     if config.device == "auto":
         if torch.cuda.is_available():