mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
make object_serializer._new_name() thread-safe; add max_threads config
This commit is contained in:
parent
bd833900a3
commit
fb9b7fb63a
@ -110,7 +110,7 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).
|
||||
pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.
|
||||
max_queue_size: Maximum number of items in the session queue.
|
||||
max_threads: Maximum number of session queue execution threads.
|
||||
max_threads: Maximum number of session queue execution threads. Autocalculated from number of GPUs if not set.
|
||||
allow_nodes: List of nodes to allow. Omit to allow all.
|
||||
deny_nodes: List of nodes to deny. Omit to deny none.
|
||||
node_cache_size: How many cached nodes to keep in memory.
|
||||
@ -182,7 +182,7 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).")
|
||||
pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.")
|
||||
max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.")
|
||||
max_threads: int = Field(default=4, description="Maximum number of session queue execution threads.")
|
||||
max_threads: Optional[int] = Field(default=None, description="Maximum number of session queue execution threads. Autocalculated from number of GPUs if not set.")
|
||||
|
||||
# NODES
|
||||
allow_nodes: Optional[list[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.")
|
||||
|
@ -1,5 +1,5 @@
|
||||
import threading
|
||||
import tempfile
|
||||
import threading
|
||||
import typing
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
@ -72,7 +72,9 @@ class ObjectSerializerDisk(ObjectSerializerBase[T]):
|
||||
|
||||
def _new_name(self) -> str:
|
||||
tid = threading.current_thread().ident
|
||||
return f"{self._obj_class_name}_{tid}_{uuid_string()}"
|
||||
# Add tid to the object name because uuid4 not thread-safe on windows
|
||||
# See https://stackoverflow.com/questions/2759644/python-multiprocessing-doesnt-play-nicely-with-uuid-uuid4
|
||||
return f"{self._obj_class_name}_{tid}-{uuid_string()}"
|
||||
|
||||
def _tempdir_cleanup(self) -> None:
|
||||
"""Calls `cleanup` on the temporary directory, if it exists."""
|
||||
|
@ -16,6 +16,7 @@ from invokeai.app.services.session_processor.session_processor_common import Can
|
||||
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context
|
||||
from invokeai.app.util.profiler import Profiler
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
from ..invoker import Invoker
|
||||
from .session_processor_base import SessionProcessorBase
|
||||
@ -40,7 +41,9 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
self._thread_semaphore = BoundedSemaphore(self._thread_limit)
|
||||
self._polling_interval = polling_interval
|
||||
|
||||
self._worker_thread_count = self._invoker.services.configuration.max_threads
|
||||
self._worker_thread_count = self._invoker.services.configuration.max_threads or len(
|
||||
TorchDevice.execution_devices()
|
||||
)
|
||||
self._session_worker_queue: Queue[SessionQueueItem] = Queue()
|
||||
self._process_lock = Lock()
|
||||
|
||||
|
@ -1,4 +1,6 @@
|
||||
from typing import TYPE_CHECKING, Dict, Literal, Optional, Union
|
||||
"""Torch Device class provides torch device selection services."""
|
||||
|
||||
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
from deprecated import deprecated
|
||||
@ -69,6 +71,14 @@ class TorchDevice:
|
||||
device = CPU_DEVICE
|
||||
return cls.normalize(device)
|
||||
|
||||
@classmethod
|
||||
def execution_devices(cls) -> List[torch.device]:
|
||||
"""Return a list of torch.devices that can be used for accelerated inference."""
|
||||
if cls._model_cache:
|
||||
return cls._model_cache.execution_devices
|
||||
else:
|
||||
return [cls.choose_torch_device]
|
||||
|
||||
@classmethod
|
||||
def choose_torch_dtype(cls, device: Optional[torch.device] = None) -> torch.dtype:
|
||||
"""Return the precision to use for accelerated inference."""
|
||||
|
Loading…
Reference in New Issue
Block a user