parallel processing working on a single GPU; not tested on multi-GPU
This commit is contained in:
parent cef51ad80d
commit 9df0980c46
@@ -23,7 +23,6 @@ from ..services.image_records.image_records_sqlite import SqliteImageRecordStorage
from ..services.images.images_default import ImageService
from ..services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
from ..services.invocation_services import InvocationServices
from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
from ..services.invoker import Invoker
from ..services.model_images.model_images_default import ModelImageFileStorageDisk
from ..services.model_manager.model_manager_default import ModelManagerService
@@ -102,7 +101,6 @@ class ApiDependencies:
            events=events,
        )
        names = SimpleNameService()
        performance_statistics = InvocationStatsService()
        session_processor = DefaultSessionProcessor()
        session_queue = SqliteSessionQueue(db=db)
        urls = LocalUrlService()
@@ -125,7 +123,6 @@ class ApiDependencies:
            model_manager=model_manager,
            download_queue=download_queue_service,
            names=names,
            performance_statistics=performance_statistics,
            session_processor=session_processor,
            session_queue=session_queue,
            urls=urls,
@@ -749,7 +749,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
            mask = mask.to(device=unet.device, dtype=unet.dtype)
            if masked_latents is not None:
                masked_latents = masked_latents.to(device=unet.device, dtype=unet.dtype)

            scheduler = get_scheduler(
                context=context,
                scheduler_info=self.unet.scheduler,
@@ -24,7 +24,7 @@ DB_FILE = Path("invokeai.db")
LEGACY_INIT_FILE = Path("invokeai.init")
DEFAULT_RAM_CACHE = 10.0
DEFAULT_CONVERT_CACHE = 20.0
DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"]
DEVICE = Literal["auto", "cpu", "cuda:0", "cuda:1", "cuda:2", "cuda:3", "cuda:4", "cuda:5", "mps"]
PRECISION = Literal["auto", "float16", "bfloat16", "float32", "autocast"]
ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"]
ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
@@ -169,6 +169,7 @@ class InvokeAIAppConfig(BaseSettings):

    # DEVICE
    device: DEVICE = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.")
    devices: Optional[list[DEVICE]] = Field(default=None, description="List of execution devices; will override default device selected.")
    precision: PRECISION = Field(default="auto", description="Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.")

    # GENERATION
@@ -178,6 +179,7 @@ class InvokeAIAppConfig(BaseSettings):
    force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).")
    pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.")
    max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.")
    max_threads: int = Field(default=4, description="Maximum number of session queue execution threads.")

    # NODES
    allow_nodes: Optional[list[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.")
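The two config hunks above add a `devices` list and a `max_threads` setting. The following standalone sketch mirrors only those fields with plain pydantic (it is not the real InvokeAIAppConfig class) and shows how a hypothetical two-GPU host might set them:

from typing import Literal, Optional

from pydantic import BaseModel, Field

# Stand-in for the DEVICE literal defined above.
DEVICE = Literal["auto", "cpu", "cuda:0", "cuda:1", "cuda:2", "cuda:3", "cuda:4", "cuda:5", "mps"]


class AppConfigSketch(BaseModel):
    # Mirrors only the settings touched by this commit; all other fields omitted.
    device: DEVICE = Field(default="auto")
    devices: Optional[list[DEVICE]] = Field(default=None)
    max_threads: int = Field(default=4)


# Hypothetical multi-GPU configuration: the explicit list overrides `device`.
config = AppConfigSketch(devices=["cuda:0", "cuda:1"], max_threads=2)
print(config.devices, config.max_threads)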
@@ -24,7 +24,6 @@ if TYPE_CHECKING:
    from .image_records.image_records_base import ImageRecordStorageBase
    from .images.images_base import ImageServiceABC
    from .invocation_cache.invocation_cache_base import InvocationCacheBase
    from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase
    from .model_images.model_images_base import ModelImageFileStorageBase
    from .model_manager.model_manager_base import ModelManagerServiceBase
    from .names.names_base import NameServiceBase
@@ -53,7 +52,6 @@ class InvocationServices:
        model_images: "ModelImageFileStorageBase",
        model_manager: "ModelManagerServiceBase",
        download_queue: "DownloadQueueServiceBase",
        performance_statistics: "InvocationStatsServiceBase",
        session_queue: "SessionQueueBase",
        session_processor: "SessionProcessorBase",
        invocation_cache: "InvocationCacheBase",
@@ -77,7 +75,6 @@ class InvocationServices:
        self.model_images = model_images
        self.model_manager = model_manager
        self.download_queue = download_queue
        self.performance_statistics = performance_statistics
        self.session_queue = session_queue
        self.session_processor = session_processor
        self.invocation_cache = invocation_cache
@@ -1,8 +1,6 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
"""Implementation of ModelManagerServiceBase."""

from typing import Optional, Set

import torch
from typing_extensions import Self

@@ -68,7 +66,6 @@ class ModelManagerService(ModelManagerServiceBase):
        model_record_service: ModelRecordServiceBase,
        download_queue: DownloadQueueServiceBase,
        events: EventServiceBase,
        execution_devices: Optional[Set[torch.device]] = None,
    ) -> Self:
        """
        Construct the model manager service instance.
@@ -78,6 +75,13 @@ class ModelManagerService(ModelManagerServiceBase):
        logger = InvokeAILogger.get_logger(cls.__name__)
        logger.setLevel(app_config.log_level.upper())

        execution_devices = (
            None
            if app_config.devices is None
            else None
            if "auto" in app_config.devices
            else {torch.device(x) for x in app_config.devices}
        )
        ram_cache = ModelCache(
            max_cache_size=app_config.ram,
            logger=logger,
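For reference, the chained conditional above collapses to the small helper below. This is only an illustrative sketch (the helper name is invented), assuming `devices` is the optional list from the app config and that `None` or an "auto" entry means the model cache picks devices on its own:

from typing import Optional, Set

import torch


def select_execution_devices(devices: Optional[list[str]]) -> Optional[Set[torch.device]]:
    # None or an explicit "auto" defers device selection to the model cache.
    if devices is None or "auto" in devices:
        return None
    return {torch.device(d) for d in devices}


print(select_execution_devices(None))                  # None
print(select_execution_devices(["auto"]))              # None
print(select_execution_devices(["cuda:0", "cuda:1"]))  # {device('cuda:0'), device('cuda:1')}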
@@ -1,8 +1,9 @@
import traceback
from contextlib import suppress
from threading import BoundedSemaphore, Thread
from queue import Queue
from threading import BoundedSemaphore, Lock, Thread
from threading import Event as ThreadEvent
from typing import Optional
from typing import Optional, Set

from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event as FastAPIEvent
@@ -10,6 +11,7 @@ from fastapi_events.typing import Event as FastAPIEvent
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
from invokeai.app.services.session_processor.session_processor_common import CanceledException
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context
@@ -23,7 +25,8 @@ from .session_processor_common import SessionProcessorStatus
class DefaultSessionProcessor(SessionProcessorBase):
    def start(self, invoker: Invoker, polling_interval: int = 1) -> None:
        self._invoker: Invoker = invoker
        self._queue_item: Optional[SessionQueueItem] = None
        self._queue_items: Set[int] = set()
        self._sessions_to_cancel: Set[int] = set()
        self._invocation: Optional[BaseInvocation] = None

        self._resume_event = ThreadEvent()
@@ -33,10 +36,14 @@

        local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_queue_event)

        self._thread_limit = self._invoker.services.model_manager.load.gpu_count
        self._thread_limit = 1
        self._thread_semaphore = BoundedSemaphore(self._thread_limit)
        self._polling_interval = polling_interval

        self._worker_thread_count = self._invoker.services.configuration.max_threads
        self._session_worker_queue: Queue[SessionQueueItem] = Queue()
        self._process_lock = Lock()

        # If profiling is enabled, create a profiler. The same profiler will be used for all sessions. Internally,
        # the profiler will create a new profile for each session.
        self._profiler = (
@@ -49,6 +56,7 @@
            else None
        )

        # main session processor loop - single thread
        self._thread = Thread(
            name="session_processor",
            target=self._process,
@@ -61,6 +69,16 @@
        )
        self._thread.start()

        # Session processor workers - multithreaded
        self._invoker.services.logger.debug(f"Starting {self._worker_thread_count} session processing threads.")
        for _i in range(0, self._worker_thread_count):
            worker = Thread(
                name="session_worker",
                target=self._process_next_session,
                daemon=True,
            )
            worker.start()

    def stop(self, *args, **kwargs) -> None:
        self._stop_event.set()

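The two hunks above split the processor into a single polling thread plus `max_threads` worker threads. A stripped-down sketch of that producer/consumer shape follows; the names (Item, poll, work) are illustrative only and not the InvokeAI API:

import time
from dataclasses import dataclass
from queue import Queue
from threading import Thread


@dataclass
class Item:
    item_id: int


work_queue: "Queue[Item]" = Queue()


def poll() -> None:
    # Stand-in for the main thread dequeuing sessions and handing them to workers.
    for i in range(4):
        work_queue.put(Item(i))
        time.sleep(0.1)


def work() -> None:
    # Stand-in for a session_worker thread draining the queue.
    while True:
        item = work_queue.get()
        print(f"executing queue item {item.item_id}")
        work_queue.task_done()


for _ in range(2):  # max_threads daemon workers
    Thread(target=work, daemon=True).start()

poller = Thread(target=poll)
poller.start()
poller.join()       # all items have been enqueued
work_queue.join()   # wait until every enqueued item has been processed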
@@ -70,18 +88,12 @@
    async def _on_queue_event(self, event: FastAPIEvent) -> None:
        event_name = event[1]["event"]

        if (
            event_name == "session_canceled"
            and self._queue_item
            and self._queue_item.item_id == event[1]["data"]["queue_item_id"]
        ):
        if event_name == "session_canceled" and event[1]["data"]["queue_item_id"] in self._queue_items:
            self._sessions_to_cancel.add(event[1]["data"]["queue_item_id"])
            self._cancel_event.set()
            self._poll_now()
        elif (
            event_name == "queue_cleared"
            and self._queue_item
            and self._queue_item.queue_id == event[1]["data"]["queue_id"]
        ):
        elif event_name == "queue_cleared" and event[1]["data"]["queue_id"] in self._queue_items:
            self._sessions_to_cancel.add(event[1]["data"]["queue_item_id"])
            self._cancel_event.set()
            self._poll_now()
        elif event_name == "batch_enqueued":
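The cancellation handling above switches from tracking one queue item to tracking a set of in-flight item ids. A minimal sketch of that bookkeeping (illustrative names only):

in_flight: set[int] = {7, 9}  # ids of queue items currently being processed
to_cancel: set[int] = set()   # ids flagged for cancellation


def on_session_canceled(queue_item_id: int) -> None:
    # Only flag items this processor is actually working on.
    if queue_item_id in in_flight:
        to_cancel.add(queue_item_id)


on_session_canceled(3)   # not in flight -> ignored
on_session_canceled(9)   # in flight -> flagged
print(to_cancel)         # {9}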
@@ -100,7 +112,7 @@
    def get_status(self) -> SessionProcessorStatus:
        return SessionProcessorStatus(
            is_started=self._resume_event.is_set(),
            is_processing=self._queue_item is not None,
            is_processing=len(self._queue_items) > 0,
        )

    def _process(
@@ -109,7 +121,7 @@
        poll_now_event: ThreadEvent,
        resume_event: ThreadEvent,
        cancel_event: ThreadEvent,
    ):
    ) -> None:
        # Outermost processor try block; any unhandled exception is a fatal processor error
        try:
            self._thread_semaphore.acquire()
@@ -119,55 +131,138 @@

            while not stop_event.is_set():
                poll_now_event.clear()
                # Middle processor try block; any unhandled exception is a non-fatal processor error
                try:
                    # If we are paused, wait for resume event
                    resume_event.wait()

                    # Get the next session to process
                    self._queue_item = self._invoker.services.session_queue.dequeue()
                    session = self._invoker.services.session_queue.dequeue()

                    if self._queue_item is None:
                    if session is None:
                        # The queue was empty, wait for next polling interval or event to try again
                        self._invoker.services.logger.debug("Waiting for next polling interval or event")
                        poll_now_event.wait(self._polling_interval)
                        continue

                    self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
                    self._queue_items.add(session.item_id)
                    self._session_worker_queue.put(session)
                    self._invoker.services.logger.debug(f"Executing queue item {session.item_id}")
                    cancel_event.clear()
        except Exception:
            # Fatal error in processor, log and pass - we're done here
            self._invoker.services.logger.error(f"Fatal Error in session processor:\n{traceback.format_exc()}")
            pass
        finally:
            stop_event.clear()
            poll_now_event.clear()
            self._queue_items.clear()
            self._thread_semaphore.release()

                    # If profiling is enabled, start the profiler
                    if self._profiler is not None:
                        self._profiler.start(profile_id=self._queue_item.session_id)
    def _process_next_session(self) -> None:
        profiler = (
            Profiler(
                logger=self._invoker.services.logger,
                output_dir=self._invoker.services.configuration.profiles_path,
                prefix=self._invoker.services.configuration.profile_prefix,
            )
            if self._invoker.services.configuration.profile_graphs
            else None
        )
        stats_service = InvocationStatsService()
        stats_service.start(self._invoker)

        while True:
            # Outer try block. Any error here is a fatal processor error
            try:
                session = self._session_worker_queue.get()
                if self._cancel_event.is_set():
                if session.item_id in self._sessions_to_cancel:
                    print("DEBUG: CANCEL")
                    continue

                if profiler is not None:
                    profiler.start(profile_id=session.session_id)

                # Prepare invocations and take the first
                self._invocation = self._queue_item.session.next()
                with self._process_lock:
                    invocation = session.session.next()

                # Loop over invocations until the session is complete or canceled
                while self._invocation is not None and not cancel_event.is_set():
                while invocation is not None:
                    if self._stop_event.is_set():
                        break
                    self._resume_event.wait()

                    self._process_next_invocation(session, invocation, stats_service)

                    # The session is complete if all invocations are complete or there was an error
                    if session.session.is_complete():
                        # Send complete event
                        self._invoker.services.events.emit_graph_execution_complete(
                            queue_batch_id=session.batch_id,
                            queue_item_id=session.item_id,
                            queue_id=session.queue_id,
                            graph_execution_state_id=session.session.id,
                        )
                        # Log stats
                        # We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
                        # we don't care about that - suppress the error.
                        with suppress(GESStatsNotFoundError):
                            stats_service.log_stats(session.session.id)
                            stats_service.reset_stats()

                        # If we are profiling, stop the profiler and dump the profile & stats
                        if self._profiler:
                            profile_path = self._profiler.stop()
                            stats_path = profile_path.with_suffix(".json")
                            stats_service.dump_stats(
                                graph_execution_state_id=session.session.id, output_path=stats_path
                            )
                        self._queue_items.remove(session.item_id)
                        invocation = None
                    else:
                        # Prepare the next invocation
                        with self._process_lock:
                            invocation = session.session.next()

            except Exception:
                # Non-fatal error in processor
                self._invoker.services.logger.error(f"Non-fatal error in session processor:\n{traceback.format_exc()}")

                # Cancel the queue item
                if session is not None:
                    self._invoker.services.session_queue.cancel_queue_item(
                        session.item_id, error=traceback.format_exc()
                    )
            finally:
                self._session_worker_queue.task_done()

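In the worker loop above, preparing the next node is wrapped in `self._process_lock` while the node execution itself runs unlocked. The sketch below shows that pattern in isolation, with a toy stand-in for the execution graph (none of these names are the real InvokeAI classes):

from threading import Lock, Thread
from typing import Optional


class ToySession:
    """Stand-in for a session's execution graph: hands out work items one at a time."""

    def __init__(self, items: list) -> None:
        self._items = list(items)

    def next(self) -> Optional[str]:
        return self._items.pop(0) if self._items else None

    def is_complete(self) -> bool:
        return not self._items


prepare_lock = Lock()  # serializes graph preparation across workers


def run_session(session: ToySession) -> None:
    with prepare_lock:
        invocation = session.next()
    while invocation is not None:
        print(f"invoking {invocation}")  # the slow, GPU-bound part runs outside the lock
        with prepare_lock:
            invocation = None if session.is_complete() else session.next()


workers = [Thread(target=run_session, args=(ToySession([f"node-{i}-a", f"node-{i}-b"]),)) for i in range(2)]
for w in workers:
    w.start()
for w in workers:
    w.join()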
    def _process_next_invocation(
        self,
        session: SessionQueueItem,
        invocation: BaseInvocation,
        stats_service: InvocationStatsService,
    ) -> None:
        # get the source node id to provide to clients (the prepared node id is not as useful)
        source_invocation_id = self._queue_item.session.prepared_source_mapping[self._invocation.id]
        source_invocation_id = session.session.prepared_source_mapping[invocation.id]

        self._invoker.services.logger.debug(f"Executing invocation {session.session.id}:{source_invocation_id}")

        # Send starting event
        self._invoker.services.events.emit_invocation_started(
            queue_batch_id=self._queue_item.batch_id,
            queue_item_id=self._queue_item.item_id,
            queue_id=self._queue_item.queue_id,
            graph_execution_state_id=self._queue_item.session_id,
            node=self._invocation.model_dump(),
            queue_batch_id=session.batch_id,
            queue_item_id=session.item_id,
            queue_id=session.queue_id,
            graph_execution_state_id=session.session_id,
            node=invocation.model_dump(),
            source_node_id=source_invocation_id,
        )

        # Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
        try:
            with self._invoker.services.performance_statistics.collect_stats(
                self._invocation, self._queue_item.session.id
            ):
                # Build invocation context (the node-facing API)
                data = InvocationContextData(
                    invocation=self._invocation,
                    invocation=invocation,
                    source_invocation_id=source_invocation_id,
                    queue_item=self._queue_item,
                    queue_item=session,
                )
                context = build_invocation_context(
                    data=data,
@@ -176,20 +271,20 @@
                )

                # Invoke the node
                outputs = self._invocation.invoke_internal(
                    context=context, services=self._invoker.services
                )
            # title = invocation.UIConfig.title
            with stats_service.collect_stats(invocation, session.session.id):
                outputs = invocation.invoke_internal(context=context, services=self._invoker.services)

            # Save outputs and history
                self._queue_item.session.complete(self._invocation.id, outputs)
            session.session.complete(invocation.id, outputs)

            # Send complete event
            self._invoker.services.events.emit_invocation_complete(
                queue_batch_id=self._queue_item.batch_id,
                queue_item_id=self._queue_item.item_id,
                queue_id=self._queue_item.queue_id,
                graph_execution_state_id=self._queue_item.session.id,
                node=self._invocation.model_dump(),
                queue_batch_id=session.batch_id,
                queue_item_id=session.item_id,
                queue_id=session.queue_id,
                graph_execution_state_id=session.session.id,
                node=invocation.model_dump(),
                source_node_id=source_invocation_id,
                result=outputs.model_dump(),
            )
@@ -215,78 +310,20 @@
            error = traceback.format_exc()

            # Save error
            self._queue_item.session.set_node_error(self._invocation.id, error)
            session.session.set_node_error(invocation.id, error)
            self._invoker.services.logger.error(
                f"Error while invoking session {self._queue_item.session_id}, invocation {self._invocation.id} ({self._invocation.get_type()}):\n{e}"
                f"Error while invoking session {session.session_id}, invocation {invocation.id} ({invocation.get_type()}):\n{e}"
            )
            self._invoker.services.logger.error(error)

            # Send error event
            self._invoker.services.events.emit_invocation_error(
                queue_batch_id=self._queue_item.session_id,
                queue_item_id=self._queue_item.item_id,
                queue_id=self._queue_item.queue_id,
                graph_execution_state_id=self._queue_item.session.id,
                node=self._invocation.model_dump(),
                queue_batch_id=session.session_id,
                queue_item_id=session.item_id,
                queue_id=session.queue_id,
                graph_execution_state_id=session.session.id,
                node=invocation.model_dump(),
                source_node_id=source_invocation_id,
                error_type=e.__class__.__name__,
                error=error,
            )
            pass

                    # The session is complete if the all invocations are complete or there was an error
                    if self._queue_item.session.is_complete() or cancel_event.is_set():
                        # Send complete event
                        self._invoker.services.events.emit_graph_execution_complete(
                            queue_batch_id=self._queue_item.batch_id,
                            queue_item_id=self._queue_item.item_id,
                            queue_id=self._queue_item.queue_id,
                            graph_execution_state_id=self._queue_item.session.id,
                        )
                        # If we are profiling, stop the profiler and dump the profile & stats
                        if self._profiler:
                            profile_path = self._profiler.stop()
                            stats_path = profile_path.with_suffix(".json")
                            self._invoker.services.performance_statistics.dump_stats(
                                graph_execution_state_id=self._queue_item.session.id, output_path=stats_path
                            )
                        # We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
                        # we don't care about that - suppress the error.
                        with suppress(GESStatsNotFoundError):
                            self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id)
                            self._invoker.services.performance_statistics.reset_stats()

                        # Set the invocation to None to prepare for the next session
                        self._invocation = None
                    else:
                        # Prepare the next invocation
                        self._invocation = self._queue_item.session.next()
                else:
                    # The queue was empty, wait for next polling interval or event to try again
                    self._invoker.services.logger.debug("Waiting for next polling interval or event")
                    poll_now_event.wait(self._polling_interval)
                    continue
                except Exception:
                    # Non-fatal error in processor
                    self._invoker.services.logger.error(
                        f"Non-fatal error in session processor:\n{traceback.format_exc()}"
                    )
                    # Cancel the queue item
                    if self._queue_item is not None:
                        self._invoker.services.session_queue.cancel_queue_item(
                            self._queue_item.item_id, error=traceback.format_exc()
                        )
                    # Reset the invocation to None to prepare for the next session
                    self._invocation = None
                    # Immediately poll for next queue item
                    poll_now_event.wait(self._polling_interval)
                    continue
        except Exception:
            # Fatal error in processor, log and pass - we're done here
            self._invoker.services.logger.error(f"Fatal Error in session processor:\n{traceback.format_exc()}")
            pass
        finally:
            stop_event.clear()
            poll_now_event.clear()
            self._queue_item = None
            self._thread_semaphore.release()

@@ -27,6 +27,7 @@ from threading import BoundedSemaphore, Lock
from typing import Dict, List, Optional, Set

import torch
from pydantic import BaseModel

from invokeai.backend.model_manager import AnyModel, SubModelType
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot
@@ -50,6 +51,26 @@ GIG = 1073741824
MB = 2**20


# GPU device can only be used by one thread at a time.
# The refcount indicates the number of models stored
# in it.
class GPUDeviceStatus(BaseModel):
    """Track of which threads are using the GPU(s) on this system."""

    device: torch.device
    thread_id: int = 0
    refcount: int = 0

    class Config:
        """Configure the base model."""

        arbitrary_types_allowed = True

    def __hash__(self) -> int:
        """Allow to be added to a set."""
        return hash(str(torch.device))


class ModelCache(ModelCacheBase[AnyModel]):
    """Implementation of ModelCacheBase."""

@@ -79,7 +100,7 @@
        """
        self._precision: torch.dtype = precision
        self._max_cache_size: float = max_cache_size
        self._execution_devices: Set[torch.device] = execution_devices or self._get_execution_devices()
        self._execution_devices: Set[GPUDeviceStatus] = self._get_execution_devices(execution_devices)
        self._storage_device: torch.device = storage_device
        self._lock = threading.Lock()
        self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
@@ -91,9 +112,10 @@

        self._lock = Lock()
        self._free_execution_device = BoundedSemaphore(len(self._execution_devices))
        self._busy_execution_devices: Set[torch.device] = set()

        self.logger.info(f"Using rendering device(s) {[self._device_name(x) for x in self._execution_devices]}")
        self.logger.info(
            f"Using rendering device(s): {', '.join(sorted([str(x.device) for x in self._execution_devices]))}"
        )

    @property
    def logger(self) -> Logger:
@@ -108,22 +130,34 @@
    @property
    def execution_devices(self) -> Set[torch.device]:
        """Return the set of available execution devices."""
        return self._execution_devices
        return {x.device for x in self._execution_devices}

    def acquire_execution_device(self, timeout: int = 0) -> torch.device:
        """Acquire and return an execution device (e.g. "cuda" for VRAM)."""
        with self._lock:
            current_thread = threading.current_thread().ident
            assert current_thread is not None

            # first try to assign a device that is already executing on this thread
            if claimed_devices := [x for x in self._execution_devices if x.thread_id == current_thread]:
                claimed_devices[0].refcount += 1
                return claimed_devices[0].device
            else:
                # this thread is not currently using any gpu. Wait for a free one
                self._free_execution_device.acquire(timeout=timeout)
                free_devices = self.execution_devices - self._busy_execution_devices
                chosen_device = list(free_devices)[0]
                self._busy_execution_devices.add(chosen_device)
                return chosen_device
                unclaimed_devices = [x for x in self._execution_devices if x.refcount == 0]
                unclaimed_devices[0].thread_id = current_thread
                unclaimed_devices[0].refcount += 1
                return unclaimed_devices[0].device

    def release_execution_device(self, device: torch.device) -> None:
        """Mark this execution device as unused."""
        with self._lock:
            current_thread = threading.current_thread().ident
            for x in self._execution_devices:
                if x.thread_id == current_thread and x.device == device:
                    x.refcount -= 1
                    if x.refcount == 0:
                        x.thread_id = 0
                        self._free_execution_device.release()
            self._busy_execution_devices.remove(device)

    @property
    def max_cache_size(self) -> float:
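The acquire/release pair above gives each worker thread an affinity to one GPU and refcounts repeated acquisitions from the same thread. Below is a self-contained sketch of that scheme using plain strings instead of torch.device; in this sketch the semaphore wait happens outside the lock so that a fully busy pool cannot block releases (a simplification, not necessarily how the class above sequences it):

import threading
from dataclasses import dataclass
from threading import BoundedSemaphore, Lock


@dataclass
class DeviceClaim:
    device: str
    thread_id: int = 0
    refcount: int = 0


class DevicePool:
    def __init__(self, devices: list) -> None:
        self._claims = [DeviceClaim(d) for d in devices]
        self._lock = Lock()
        self._free = BoundedSemaphore(len(self._claims))

    def acquire(self) -> str:
        me = threading.get_ident()
        with self._lock:
            mine = [c for c in self._claims if c.thread_id == me]
            if mine:  # this thread already holds a device; just bump its refcount
                mine[0].refcount += 1
                return mine[0].device
        self._free.acquire()  # wait for any device to become free
        with self._lock:
            claim = next(c for c in self._claims if c.refcount == 0)
            claim.thread_id = me
            claim.refcount = 1
            return claim.device

    def release(self, device: str) -> None:
        me = threading.get_ident()
        with self._lock:
            for c in self._claims:
                if c.thread_id == me and c.device == device:
                    c.refcount -= 1
                    if c.refcount == 0:
                        c.thread_id = 0
                        self._free.release()


pool = DevicePool(["cuda:0", "cuda:1"])
device = pool.acquire()
print("claimed", device)
pool.release(device)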
@@ -174,7 +208,7 @@
        if key in self._cached_models:
            return
        self.make_room(size)
        cache_record = CacheRecord(key, model, size)
        cache_record = CacheRecord(key, model=model, size=size)
        self._cached_models[key] = cache_record
        self._cache_stack.append(key)

@@ -361,14 +395,15 @@
            raise torch.cuda.OutOfMemoryError

    @staticmethod
    def _get_execution_devices() -> Set[torch.device]:
    def _get_execution_devices(devices: Optional[Set[torch.device]] = None) -> Set[GPUDeviceStatus]:
        if not devices:
            default_device = choose_torch_device()
            if default_device != torch.device("cuda"):
                return {default_device}

            # we get here if the default device is cuda, and return each of the
            # cuda devices.
            return {torch.device(f"cuda:{x}") for x in range(0, torch.cuda.device_count())}
                devices = {default_device}
            else:
                # we get here if the default device is cuda, and return each of the cuda devices.
                devices = {torch.device(f"cuda:{x}") for x in range(0, torch.cuda.device_count())}
        return {GPUDeviceStatus(device=x) for x in devices}

    @staticmethod
    def _device_name(device: torch.device) -> str:
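When no explicit devices are passed in, the method above enumerates the visible CUDA devices. A rough sketch of that default path, using torch directly instead of InvokeAI's choose_torch_device helper (so the cpu fallback here is an approximation):

import torch


def default_execution_devices() -> set:
    # With CUDA present, claim every visible CUDA device; otherwise fall back to a
    # single default device (this sketch ignores the mps branch for brevity).
    if torch.cuda.is_available():
        return {torch.device(f"cuda:{i}") for i in range(torch.cuda.device_count())}
    return {torch.device("cpu")}


print(default_execution_devices())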
@@ -17,7 +17,6 @@ from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.images.images_default import ImageService
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
from invokeai.app.services.invoker import Invoker
from invokeai.backend.util.logging import InvokeAILogger
from tests.fixtures.sqlite_database import create_mock_sqlite_database  # noqa: F401
@@ -48,7 +47,6 @@ def mock_services() -> InvocationServices:
        model_manager=None,  # type: ignore
        download_queue=None,  # type: ignore
        names=None,  # type: ignore
        performance_statistics=InvocationStatsService(),
        session_processor=None,  # type: ignore
        session_queue=None,  # type: ignore
        urls=None,  # type: ignore