make object_serializer._new_name() thread-safe; add max_threads config

Lincoln Stein 2024-04-16 15:23:49 -04:00
parent bd833900a3
commit fb9b7fb63a
4 changed files with 21 additions and 6 deletions

View File

@@ -110,7 +110,7 @@ class InvokeAIAppConfig(BaseSettings):
         force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).
         pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.
         max_queue_size: Maximum number of items in the session queue.
-        max_threads: Maximum number of session queue execution threads.
+        max_threads: Maximum number of session queue execution threads. Autocalculated from number of GPUs if not set.
         allow_nodes: List of nodes to allow. Omit to allow all.
         deny_nodes: List of nodes to deny. Omit to deny none.
         node_cache_size: How many cached nodes to keep in memory.
@@ -182,7 +182,7 @@ class InvokeAIAppConfig(BaseSettings):
     force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).")
     pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.")
     max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.")
-    max_threads: int = Field(default=4, description="Maximum number of session queue execution threads.")
+    max_threads: Optional[int] = Field(default=None, description="Maximum number of session queue execution threads. Autocalculated from number of GPUs if not set.")

     # NODES
     allow_nodes: Optional[list[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.")
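Taken together, these two hunks change `max_threads` from a hard default of 4 to an optional setting that is resolved at startup. A minimal sketch of the resulting behavior, assuming pydantic v2's pydantic-settings package (which `InvokeAIAppConfig` builds on); `ExampleConfig` and the fixed device count are illustrative stand-ins for the real config class and GPU probe:

```python
from typing import Optional

from pydantic import Field
from pydantic_settings import BaseSettings


class ExampleConfig(BaseSettings):
    # None means "autocalculate from the number of execution devices at startup".
    max_threads: Optional[int] = Field(default=None, description="Maximum number of session queue execution threads.")


cfg = ExampleConfig()         # max_threads is None unless set via env var or settings file
device_count = 2              # stand-in for len(TorchDevice.execution_devices())
worker_threads = cfg.max_threads or device_count
print(worker_threads)         # -> 2 when unset; the configured value otherwise
```

One subtlety of the `or` idiom: an explicit `max_threads: 0` would also fall back to the device count, since the new field (unlike `max_queue_size`) carries no `gt=0` constraint.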

View File

@@ -1,5 +1,5 @@
-import threading
 import tempfile
+import threading
 import typing
 from dataclasses import dataclass
 from pathlib import Path
@@ -72,7 +72,9 @@ class ObjectSerializerDisk(ObjectSerializerBase[T]):

     def _new_name(self) -> str:
         tid = threading.current_thread().ident
-        return f"{self._obj_class_name}_{tid}_{uuid_string()}"
+        # Add tid to the object name because uuid4 is not thread-safe on Windows.
+        # See https://stackoverflow.com/questions/2759644/python-multiprocessing-doesnt-play-nicely-with-uuid-uuid4
+        return f"{self._obj_class_name}_{tid}-{uuid_string()}"

     def _tempdir_cleanup(self) -> None:
         """Calls `cleanup` on the temporary directory, if it exists."""

View File

@@ -16,6 +16,7 @@ from invokeai.app.services.session_processor.session_processor_common import Can
 from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
 from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context
 from invokeai.app.util.profiler import Profiler
+from invokeai.backend.util.devices import TorchDevice

 from ..invoker import Invoker
 from .session_processor_base import SessionProcessorBase
@@ -40,7 +41,9 @@ class DefaultSessionProcessor(SessionProcessorBase):
         self._thread_semaphore = BoundedSemaphore(self._thread_limit)
         self._polling_interval = polling_interval
-        self._worker_thread_count = self._invoker.services.configuration.max_threads
+        self._worker_thread_count = self._invoker.services.configuration.max_threads or len(
+            TorchDevice.execution_devices()
+        )
         self._session_worker_queue: Queue[SessionQueueItem] = Queue()
         self._process_lock = Lock()
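For context, the fields above feed a standard worker-pool pattern: `_worker_thread_count` threads drain the shared `_session_worker_queue`, gated by the semaphore. A hedged sketch of that pattern (illustrative only, not the processor's actual loop):

```python
from queue import Queue
from threading import BoundedSemaphore, Thread

worker_thread_count = 2            # max_threads, or len(TorchDevice.execution_devices()) when unset
session_worker_queue: Queue = Queue()
thread_semaphore = BoundedSemaphore(worker_thread_count)


def worker() -> None:
    while True:
        item = session_worker_queue.get()
        with thread_semaphore:     # cap concurrent session execution
            ...                    # run item's session graph here
        session_worker_queue.task_done()


for _ in range(worker_thread_count):
    Thread(target=worker, daemon=True).start()
```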

View File

@@ -1,4 +1,6 @@
-from typing import TYPE_CHECKING, Dict, Literal, Optional, Union
+"""Torch Device class provides torch device selection services."""
+
+from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union

 import torch
 from deprecated import deprecated
@@ -69,6 +71,14 @@ class TorchDevice:
             device = CPU_DEVICE
         return cls.normalize(device)

+    @classmethod
+    def execution_devices(cls) -> List[torch.device]:
+        """Return a list of torch.devices that can be used for accelerated inference."""
+        if cls._model_cache:
+            return cls._model_cache.execution_devices
+        else:
+            # Call the classmethod; the bare attribute would put a bound method,
+            # not a torch.device, in the list.
+            return [cls.choose_torch_device()]
+
     @classmethod
     def choose_torch_dtype(cls, device: Optional[torch.device] = None) -> torch.dtype:
         """Return the precision to use for accelerated inference."""