mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
make object_serializer._new_name() thread-safe; add max_threads config
parent bd833900a3
commit fb9b7fb63a
invokeai/app/services/config/config_default.py

@@ -110,7 +110,7 @@ class InvokeAIAppConfig(BaseSettings):
         force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).
         pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.
         max_queue_size: Maximum number of items in the session queue.
-        max_threads: Maximum number of session queue execution threads.
+        max_threads: Maximum number of session queue execution threads. Autocalculated from number of GPUs if not set.
         allow_nodes: List of nodes to allow. Omit to allow all.
         deny_nodes: List of nodes to deny. Omit to deny none.
         node_cache_size: How many cached nodes to keep in memory.
@@ -182,7 +182,7 @@ class InvokeAIAppConfig(BaseSettings):
     force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).")
     pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.")
     max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.")
-    max_threads: int = Field(default=4, description="Maximum number of session queue execution threads.")
+    max_threads: Optional[int] = Field(default=None, description="Maximum number of session queue execution threads. Autocalculated from number of GPUs if not set.")

     # NODES
     allow_nodes: Optional[list[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.")
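With this change, `max_threads` becomes an optional setting: leaving it unset (`None`) tells the session processor to derive the worker count from the available execution devices, while an explicit value still wins. A minimal sketch of that behavior, using a hypothetical `_ConfigSketch` stand-in for `InvokeAIAppConfig` (requires the pydantic-settings package):

from typing import Optional

from pydantic import Field
from pydantic_settings import BaseSettings


class _ConfigSketch(BaseSettings):
    # Hypothetical stand-in for InvokeAIAppConfig, reduced to the one field.
    max_threads: Optional[int] = Field(default=None, description="Maximum number of session queue execution threads.")


config = _ConfigSketch()
# None is falsy, so `or` falls through to the device-derived count;
# the literal 2 stands in for len(TorchDevice.execution_devices()).
assert (config.max_threads or 2) == 2

explicit = _ConfigSketch(max_threads=8)
assert (explicit.max_threads or 2) == 8

One quirk of the `or` fallback: an explicit `max_threads: 0` is falsy too, so it behaves the same as omitting the setting.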
invokeai/app/services/object_serializer/object_serializer_disk.py

@@ -1,5 +1,5 @@
-import threading
 import tempfile
+import threading
 import typing
 from dataclasses import dataclass
 from pathlib import Path
@@ -72,7 +72,9 @@ class ObjectSerializerDisk(ObjectSerializerBase[T]):

     def _new_name(self) -> str:
         tid = threading.current_thread().ident
-        return f"{self._obj_class_name}_{tid}_{uuid_string()}"
+        # Add tid to the object name because uuid4 not thread-safe on windows
+        # See https://stackoverflow.com/questions/2759644/python-multiprocessing-doesnt-play-nicely-with-uuid-uuid4
+        return f"{self._obj_class_name}_{tid}-{uuid_string()}"

     def _tempdir_cleanup(self) -> None:
         """Calls `cleanup` on the temporary directory, if it exists."""
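The motivation for the `{tid}` prefix is the linked Stack Overflow thread: `uuid4` can hand identical values to concurrent workers on some platforms, so the thread id disambiguates names even if two threads draw the same UUID. A self-contained sketch of the scheme, with `uuid.uuid4().hex` standing in for InvokeAI's `uuid_string()` helper:

import threading
import uuid


def _new_name(obj_class_name: str) -> str:
    # Mirrors the commit's scheme: thread id, then a dash, then a fresh UUID.
    tid = threading.current_thread().ident
    return f"{obj_class_name}_{tid}-{uuid.uuid4().hex}"


names: set[str] = set()
names_lock = threading.Lock()


def worker() -> None:
    for _ in range(1000):
        name = _new_name("Tensor")
        with names_lock:
            assert name not in names  # a collision would surface here
            names.add(name)


threads = [threading.Thread(target=worker) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
assert len(names) == 8000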
invokeai/app/services/session_processor/session_processor_default.py

@@ -16,6 +16,7 @@ from invokeai.app.services.session_processor.session_processor_common import Can
 from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
 from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context
 from invokeai.app.util.profiler import Profiler
+from invokeai.backend.util.devices import TorchDevice

 from ..invoker import Invoker
 from .session_processor_base import SessionProcessorBase
@@ -40,7 +41,9 @@ class DefaultSessionProcessor(SessionProcessorBase):
         self._thread_semaphore = BoundedSemaphore(self._thread_limit)
         self._polling_interval = polling_interval

-        self._worker_thread_count = self._invoker.services.configuration.max_threads
+        self._worker_thread_count = self._invoker.services.configuration.max_threads or len(
+            TorchDevice.execution_devices()
+        )
         self._session_worker_queue: Queue[SessionQueueItem] = Queue()
         self._process_lock = Lock()

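The `or` expression is the autocalculation promised by the config docs: an explicit `max_threads` takes precedence, otherwise the pool gets one worker thread per execution device. A rough sketch of the resulting pool shape; the queue-and-sentinel plumbing is illustrative, not the actual `DefaultSessionProcessor` code, and the device list stands in for `TorchDevice.execution_devices()`:

import threading
from queue import Queue

max_threads = None  # stands in for configuration.max_threads
execution_devices = ["cuda:0", "cuda:1"]  # stands in for TorchDevice.execution_devices()

worker_thread_count = max_threads or len(execution_devices)  # -> 2

work: Queue[int] = Queue()
results: list[int] = []  # list.append is atomic under the GIL


def worker() -> None:
    while True:
        item = work.get()
        if item < 0:  # negative sentinel shuts the worker down
            work.task_done()
            break
        results.append(item * item)
        work.task_done()


threads = [threading.Thread(target=worker, daemon=True) for _ in range(worker_thread_count)]
for t in threads:
    t.start()
for i in range(10):
    work.put(i)
for _ in threads:
    work.put(-1)
work.join()
assert sorted(results) == [i * i for i in range(10)]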
invokeai/backend/util/devices.py

@@ -1,4 +1,6 @@
-from typing import TYPE_CHECKING, Dict, Literal, Optional, Union
+"""Torch Device class provides torch device selection services."""
+
+from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union

 import torch
 from deprecated import deprecated
@@ -69,6 +71,14 @@ class TorchDevice:
             device = CPU_DEVICE
         return cls.normalize(device)

+    @classmethod
+    def execution_devices(cls) -> List[torch.device]:
+        """Return a list of torch.devices that can be used for accelerated inference."""
+        if cls._model_cache:
+            return cls._model_cache.execution_devices
+        else:
+            return [cls.choose_torch_device()]
+
     @classmethod
     def choose_torch_dtype(cls, device: Optional[torch.device] = None) -> torch.dtype:
         """Return the precision to use for accelerated inference."""
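When no model cache has been registered, `execution_devices()` falls back to a single-element list, so the session processor's derived worker count is at least 1. A hedged sketch of that fallback shape, using a hypothetical `_DeviceSketch` rather than the real `TorchDevice` (whose `_model_cache` is presumably populated elsewhere, e.g. by the model manager):

from typing import List

import torch


class _DeviceSketch:
    # Hypothetical stand-in for TorchDevice; _model_cache stays None here.
    _model_cache = None

    @classmethod
    def choose_torch_device(cls) -> torch.device:
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    @classmethod
    def execution_devices(cls) -> List[torch.device]:
        if cls._model_cache:
            return cls._model_cache.execution_devices
        # No cache registered: degrade to the single chosen device.
        return [cls.choose_torch_device()]


print(len(_DeviceSketch.execution_devices()))  # 1 without a model cache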