make object_serializer._new_name() thread-safe; add max_threads config

This commit is contained in:
Lincoln Stein 2024-04-16 15:23:49 -04:00
parent bd833900a3
commit fb9b7fb63a
4 changed files with 21 additions and 6 deletions

View File

@ -110,7 +110,7 @@ class InvokeAIAppConfig(BaseSettings):
force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty). force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).
pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting. pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.
max_queue_size: Maximum number of items in the session queue. max_queue_size: Maximum number of items in the session queue.
max_threads: Maximum number of session queue execution threads. max_threads: Maximum number of session queue execution threads. Autocalculated from number of GPUs if not set.
allow_nodes: List of nodes to allow. Omit to allow all. allow_nodes: List of nodes to allow. Omit to allow all.
deny_nodes: List of nodes to deny. Omit to deny none. deny_nodes: List of nodes to deny. Omit to deny none.
node_cache_size: How many cached nodes to keep in memory. node_cache_size: How many cached nodes to keep in memory.
@ -182,7 +182,7 @@ class InvokeAIAppConfig(BaseSettings):
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).") force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).")
pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.") pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.")
max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.") max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.")
max_threads: int = Field(default=4, description="Maximum number of session queue execution threads.") max_threads: Optional[int] = Field(default=None, description="Maximum number of session queue execution threads. Autocalculated from number of GPUs if not set.")
# NODES # NODES
allow_nodes: Optional[list[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.") allow_nodes: Optional[list[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.")

View File

@ -1,5 +1,5 @@
import threading
import tempfile import tempfile
import threading
import typing import typing
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
@ -72,7 +72,9 @@ class ObjectSerializerDisk(ObjectSerializerBase[T]):
def _new_name(self) -> str: def _new_name(self) -> str:
tid = threading.current_thread().ident tid = threading.current_thread().ident
return f"{self._obj_class_name}_{tid}_{uuid_string()}" # Add tid to the object name because uuid4 not thread-safe on windows
# See https://stackoverflow.com/questions/2759644/python-multiprocessing-doesnt-play-nicely-with-uuid-uuid4
return f"{self._obj_class_name}_{tid}-{uuid_string()}"
def _tempdir_cleanup(self) -> None: def _tempdir_cleanup(self) -> None:
"""Calls `cleanup` on the temporary directory, if it exists.""" """Calls `cleanup` on the temporary directory, if it exists."""

View File

@ -16,6 +16,7 @@ from invokeai.app.services.session_processor.session_processor_common import Can
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context
from invokeai.app.util.profiler import Profiler from invokeai.app.util.profiler import Profiler
from invokeai.backend.util.devices import TorchDevice
from ..invoker import Invoker from ..invoker import Invoker
from .session_processor_base import SessionProcessorBase from .session_processor_base import SessionProcessorBase
@ -40,7 +41,9 @@ class DefaultSessionProcessor(SessionProcessorBase):
self._thread_semaphore = BoundedSemaphore(self._thread_limit) self._thread_semaphore = BoundedSemaphore(self._thread_limit)
self._polling_interval = polling_interval self._polling_interval = polling_interval
self._worker_thread_count = self._invoker.services.configuration.max_threads self._worker_thread_count = self._invoker.services.configuration.max_threads or len(
TorchDevice.execution_devices()
)
self._session_worker_queue: Queue[SessionQueueItem] = Queue() self._session_worker_queue: Queue[SessionQueueItem] = Queue()
self._process_lock = Lock() self._process_lock = Lock()

View File

@ -1,4 +1,6 @@
from typing import TYPE_CHECKING, Dict, Literal, Optional, Union """Torch Device class provides torch device selection services."""
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union
import torch import torch
from deprecated import deprecated from deprecated import deprecated
@ -69,6 +71,14 @@ class TorchDevice:
device = CPU_DEVICE device = CPU_DEVICE
return cls.normalize(device) return cls.normalize(device)
@classmethod
def execution_devices(cls) -> List[torch.device]:
    """Return a list of torch.devices that can be used for accelerated inference.

    If a model cache has been registered on the class, defer to its configured
    execution devices. Otherwise fall back to a single-element list containing
    the device chosen by `choose_torch_device()`.
    """
    if cls._model_cache:
        return cls._model_cache.execution_devices
    else:
        # BUG FIX: the original returned [cls.choose_torch_device] — the bound
        # method object itself, not a torch.device. Callers (e.g. the session
        # processor doing len(TorchDevice.execution_devices())) would get a list
        # whose element is a callable, violating the List[torch.device] contract.
        # The method must be invoked.
        return [cls.choose_torch_device()]
@classmethod @classmethod
def choose_torch_dtype(cls, device: Optional[torch.device] = None) -> torch.dtype: def choose_torch_dtype(cls, device: Optional[torch.device] = None) -> torch.dtype:
"""Return the precision to use for accelerated inference.""" """Return the precision to use for accelerated inference."""