diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py
index 9a6c7416f6..7332b35c08 100644
--- a/invokeai/app/api/dependencies.py
+++ b/invokeai/app/api/dependencies.py
@@ -23,7 +23,6 @@ from ..services.image_records.image_records_sqlite import SqliteImageRecordStorage
 from ..services.images.images_default import ImageService
 from ..services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
 from ..services.invocation_services import InvocationServices
-from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
 from ..services.invoker import Invoker
 from ..services.model_images.model_images_default import ModelImageFileStorageDisk
 from ..services.model_manager.model_manager_default import ModelManagerService
@@ -102,7 +101,6 @@ class ApiDependencies:
             events=events,
         )
         names = SimpleNameService()
-        performance_statistics = InvocationStatsService()
         session_processor = DefaultSessionProcessor()
         session_queue = SqliteSessionQueue(db=db)
         urls = LocalUrlService()
@@ -125,7 +123,6 @@ class ApiDependencies:
            model_manager=model_manager,
            download_queue=download_queue_service,
            names=names,
-           performance_statistics=performance_statistics,
            session_processor=session_processor,
            session_queue=session_queue,
            urls=urls,
diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index bc79efdeba..7845cbba03 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -749,7 +749,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
             mask = mask.to(device=unet.device, dtype=unet.dtype)
             if masked_latents is not None:
                 masked_latents = masked_latents.to(device=unet.device, dtype=unet.dtype)
-
         scheduler = get_scheduler(
             context=context,
             scheduler_info=self.unet.scheduler,
diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py
index 2e9578a56a..258cd58e8d 100644
--- a/invokeai/app/services/config/config_default.py
+++ b/invokeai/app/services/config/config_default.py
@@ -24,7 +24,7 @@ DB_FILE = Path("invokeai.db")
 LEGACY_INIT_FILE = Path("invokeai.init")
 DEFAULT_RAM_CACHE = 10.0
 DEFAULT_CONVERT_CACHE = 20.0
-DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"]
+DEVICE = Literal["auto", "cpu", "cuda:0", "cuda:1", "cuda:2", "cuda:3", "cuda:4", "cuda:5", "mps"]
 PRECISION = Literal["auto", "float16", "bfloat16", "float32", "autocast"]
 ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"]
 ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
@@ -169,6 +169,7 @@ class InvokeAIAppConfig(BaseSettings):
 
     # DEVICE
     device: DEVICE = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.")
+    devices: Optional[list[DEVICE]] = Field(default=None, description="List of execution devices; if set, overrides the default device selection.")
     precision: PRECISION = Field(default="auto", description="Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.")
 
     # GENERATION
@@ -178,6 +179,7 @@
     force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).")
     pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.")
     max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.")
+    max_threads: int = Field(default=4, description="Maximum number of session queue execution threads.")
 
     # NODES
     allow_nodes: Optional[list[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.")
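Taken together, `devices` and `max_threads` are the user-facing knobs for multi-GPU execution. A minimal sketch of how they could be exercised (not part of the diff; constructing `InvokeAIAppConfig` directly like this is illustrative only, since in practice the values come from the app's configuration file):

```python
from invokeai.app.services.config.config_default import InvokeAIAppConfig

# Pin generation to two specific GPUs and run two worker threads.
# `devices` overrides the single `device` setting; `max_threads` sizes the
# session-processor worker pool introduced later in this diff.
config = InvokeAIAppConfig(devices=["cuda:0", "cuda:1"], max_threads=2)

assert config.device == "auto"  # the legacy single-device field keeps its default
print(config.devices, config.max_threads)
```

Leaving `devices` unset, or including `"auto"` in the list, falls back to the existing single-device selection.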
diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py
index f4fce6098f..0e1ec123ca 100644
--- a/invokeai/app/services/invocation_services.py
+++ b/invokeai/app/services/invocation_services.py
@@ -24,7 +24,6 @@ if TYPE_CHECKING:
     from .image_records.image_records_base import ImageRecordStorageBase
     from .images.images_base import ImageServiceABC
     from .invocation_cache.invocation_cache_base import InvocationCacheBase
-    from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase
     from .model_images.model_images_base import ModelImageFileStorageBase
     from .model_manager.model_manager_base import ModelManagerServiceBase
     from .names.names_base import NameServiceBase
@@ -53,7 +52,6 @@ class InvocationServices:
         model_images: "ModelImageFileStorageBase",
         model_manager: "ModelManagerServiceBase",
         download_queue: "DownloadQueueServiceBase",
-        performance_statistics: "InvocationStatsServiceBase",
         session_queue: "SessionQueueBase",
         session_processor: "SessionProcessorBase",
         invocation_cache: "InvocationCacheBase",
@@ -77,7 +75,6 @@ class InvocationServices:
         self.model_images = model_images
         self.model_manager = model_manager
         self.download_queue = download_queue
-        self.performance_statistics = performance_statistics
         self.session_queue = session_queue
         self.session_processor = session_processor
         self.invocation_cache = invocation_cache
diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py
index e2da8e2712..241259c803 100644
--- a/invokeai/app/services/model_manager/model_manager_default.py
+++ b/invokeai/app/services/model_manager/model_manager_default.py
@@ -1,8 +1,6 @@
 # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
 """Implementation of ModelManagerServiceBase."""
 
-from typing import Optional, Set
-
 import torch
 from typing_extensions import Self
 
@@ -68,7 +66,6 @@ class ModelManagerService(ModelManagerServiceBase):
         model_record_service: ModelRecordServiceBase,
         download_queue: DownloadQueueServiceBase,
         events: EventServiceBase,
-        execution_devices: Optional[Set[torch.device]] = None,
     ) -> Self:
         """
         Construct the model manager service instance.
@@ -78,6 +75,13 @@ class ModelManagerService(ModelManagerServiceBase):
         logger = InvokeAILogger.get_logger(cls.__name__)
         logger.setLevel(app_config.log_level.upper())
 
+        execution_devices = (
+            None
+            if app_config.devices is None
+            else None
+            if "auto" in app_config.devices
+            else {torch.device(x) for x in app_config.devices}
+        )
         ram_cache = ModelCache(
             max_cache_size=app_config.ram,
             logger=logger,
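The chained conditional above is the whole device-resolution rule: an unset `devices` setting or any `"auto"` entry defers to autodetection, anything else becomes a set of `torch.device` objects. Pulled out as a standalone function (a hypothetical helper, not in the diff), the three cases are easier to see:

```python
from typing import Optional, Set

import torch


def resolve_execution_devices(devices: Optional[list[str]]) -> Optional[Set[torch.device]]:
    """Hypothetical helper mirroring the expression in build_model_manager."""
    if devices is None or "auto" in devices:
        return None  # let the model cache autodetect the available GPUs
    return {torch.device(d) for d in devices}


assert resolve_execution_devices(None) is None
assert resolve_execution_devices(["auto", "cuda:0"]) is None
assert resolve_execution_devices(["cuda:0", "cuda:1"]) == {torch.device("cuda:0"), torch.device("cuda:1")}
```

Returning `None` lets `ModelCache._get_execution_devices()` fall back to enumerating the available CUDA devices, matching the old behavior.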
diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py
index d6791fbd57..3088d99c5d 100644
--- a/invokeai/app/services/session_processor/session_processor_default.py
+++ b/invokeai/app/services/session_processor/session_processor_default.py
@@ -1,8 +1,9 @@
 import traceback
 from contextlib import suppress
-from threading import BoundedSemaphore, Thread
+from queue import Queue
+from threading import BoundedSemaphore, Lock, Thread
 from threading import Event as ThreadEvent
-from typing import Optional
+from typing import Optional, Set
 
 from fastapi_events.handlers.local import local_handler
 from fastapi_events.typing import Event as FastAPIEvent
@@ -10,6 +11,7 @@ from fastapi_events.typing import Event as FastAPIEvent
 from invokeai.app.invocations.baseinvocation import BaseInvocation
 from invokeai.app.services.events.events_base import EventServiceBase
 from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
+from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
 from invokeai.app.services.session_processor.session_processor_common import CanceledException
 from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
 from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context
@@ -23,7 +25,8 @@ from .session_processor_common import SessionProcessorStatus
 class DefaultSessionProcessor(SessionProcessorBase):
     def start(self, invoker: Invoker, polling_interval: int = 1) -> None:
         self._invoker: Invoker = invoker
-        self._queue_item: Optional[SessionQueueItem] = None
+        self._queue_items: Set[int] = set()
+        self._sessions_to_cancel: Set[int] = set()
         self._invocation: Optional[BaseInvocation] = None
 
         self._resume_event = ThreadEvent()
@@ -33,10 +36,14 @@ class DefaultSessionProcessor(SessionProcessorBase):
 
         local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_queue_event)
 
-        self._thread_limit = self._invoker.services.model_manager.load.gpu_count
+        self._thread_limit = 1
         self._thread_semaphore = BoundedSemaphore(self._thread_limit)
         self._polling_interval = polling_interval
 
+        self._worker_thread_count = self._invoker.services.configuration.max_threads
+        self._session_worker_queue: Queue[SessionQueueItem] = Queue()
+        self._process_lock = Lock()
+
         # If profiling is enabled, create a profiler. The same profiler will be used for all sessions. Internally,
         # the profiler will create a new profile for each session.
         self._profiler = (
@@ -49,6 +56,7 @@
             else None
         )
 
+        # main session processor loop - single thread
         self._thread = Thread(
             name="session_processor",
             target=self._process,
@@ -61,6 +69,16 @@
         )
         self._thread.start()
 
+        # Session processor workers - multithreaded
+        self._invoker.services.logger.debug(f"Starting {self._worker_thread_count} session processing threads.")
+        for _i in range(0, self._worker_thread_count):
+            worker = Thread(
+                name="session_worker",
+                target=self._process_next_session,
+                daemon=True,
+            )
+            worker.start()
+
     def stop(self, *args, **kwargs) -> None:
         self._stop_event.set()
 
@@ -70,18 +88,12 @@
     async def _on_queue_event(self, event: FastAPIEvent) -> None:
         event_name = event[1]["event"]
 
-        if (
-            event_name == "session_canceled"
-            and self._queue_item
-            and self._queue_item.item_id == event[1]["data"]["queue_item_id"]
-        ):
+        if event_name == "session_canceled" and event[1]["data"]["queue_item_id"] in self._queue_items:
+            self._sessions_to_cancel.add(event[1]["data"]["queue_item_id"])
             self._cancel_event.set()
             self._poll_now()
-        elif (
-            event_name == "queue_cleared"
-            and self._queue_item
-            and self._queue_item.queue_id == event[1]["data"]["queue_id"]
-        ):
+        elif event_name == "queue_cleared" and event[1]["data"]["queue_id"] in self._queue_items:
+            self._sessions_to_cancel.add(event[1]["data"]["queue_item_id"])
             self._cancel_event.set()
             self._poll_now()
         elif event_name == "batch_enqueued":
@@ -100,7 +112,7 @@
     def get_status(self) -> SessionProcessorStatus:
         return SessionProcessorStatus(
             is_started=self._resume_event.is_set(),
-            is_processing=self._queue_item is not None,
+            is_processing=len(self._queue_items) > 0,
         )
 
     def _process(
@@ -109,7 +121,7 @@
         poll_now_event: ThreadEvent,
         resume_event: ThreadEvent,
         cancel_event: ThreadEvent,
-    ):
+    ) -> None:
         # Outermost processor try block; any unhandled exception is a fatal processor error
         try:
             self._thread_semaphore.acquire()
@@ -119,168 +131,21 @@
             while not stop_event.is_set():
                 poll_now_event.clear()
-                # Middle processor try block; any unhandled exception is a non-fatal processor error
-                try:
-                    # If we are paused, wait for resume event
-                    resume_event.wait()
+                resume_event.wait()
 
-                    # Get the next session to process
-                    self._queue_item = self._invoker.services.session_queue.dequeue()
+                # Get the next session to process
+                session = self._invoker.services.session_queue.dequeue()
 
-                    if self._queue_item is None:
-                        # The queue was empty, wait for next polling interval or event to try again
-                        self._invoker.services.logger.debug("Waiting for next polling interval or event")
-                        poll_now_event.wait(self._polling_interval)
-                        continue
-
-                    self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
-                    cancel_event.clear()
-
-                    # If profiling is enabled, start the profiler
-                    if self._profiler is not None:
-                        self._profiler.start(profile_id=self._queue_item.session_id)
-
-                    # Prepare invocations and take the first
-                    self._invocation = self._queue_item.session.next()
-
-                    # Loop over invocations until the session is complete or canceled
-                    while self._invocation is not None and not cancel_event.is_set():
-                        # get the source node id to provide to clients (the prepared node id is not as useful)
-                        source_invocation_id = self._queue_item.session.prepared_source_mapping[self._invocation.id]
-
-                        # Send starting event
-                        self._invoker.services.events.emit_invocation_started(
-                            queue_batch_id=self._queue_item.batch_id,
-                            queue_item_id=self._queue_item.item_id,
-                            queue_id=self._queue_item.queue_id,
-                            graph_execution_state_id=self._queue_item.session_id,
-                            node=self._invocation.model_dump(),
-                            source_node_id=source_invocation_id,
-                        )
-
-                        # Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
-                        try:
-                            with self._invoker.services.performance_statistics.collect_stats(
-                                self._invocation, self._queue_item.session.id
-                            ):
-                                # Build invocation context (the node-facing API)
-                                data = InvocationContextData(
-                                    invocation=self._invocation,
-                                    source_invocation_id=source_invocation_id,
-                                    queue_item=self._queue_item,
-                                )
-                                context = build_invocation_context(
-                                    data=data,
-                                    services=self._invoker.services,
-                                    cancel_event=self._cancel_event,
-                                )
-
-                                # Invoke the node
-                                outputs = self._invocation.invoke_internal(
-                                    context=context, services=self._invoker.services
-                                )
-
-                                # Save outputs and history
-                                self._queue_item.session.complete(self._invocation.id, outputs)
-
-                                # Send complete event
-                                self._invoker.services.events.emit_invocation_complete(
-                                    queue_batch_id=self._queue_item.batch_id,
-                                    queue_item_id=self._queue_item.item_id,
-                                    queue_id=self._queue_item.queue_id,
-                                    graph_execution_state_id=self._queue_item.session.id,
-                                    node=self._invocation.model_dump(),
-                                    source_node_id=source_invocation_id,
-                                    result=outputs.model_dump(),
-                                )
-
-                        except KeyboardInterrupt:
-                            # TODO(MM2): Create an event for this
-                            pass
-
-                        except CanceledException:
-                            # When the user cancels the graph, we first set the cancel event. The event is checked
-                            # between invocations, in this loop. Some invocations are long-running, and we need to
-                            # be able to cancel them mid-execution.
-                            #
-                            # For example, denoising is a long-running invocation with many steps. A step callback
-                            # is executed after each step. This step callback checks if the canceled event is set,
-                            # then raises a CanceledException to stop execution immediately.
-                            #
-                            # When we get a CanceledException, we don't need to do anything - just pass and let the
-                            # loop go to its next iteration, and the cancel event will be handled correctly.
-                            pass
-
-                        except Exception as e:
-                            error = traceback.format_exc()
-
-                            # Save error
-                            self._queue_item.session.set_node_error(self._invocation.id, error)
-                            self._invoker.services.logger.error(
-                                f"Error while invoking session {self._queue_item.session_id}, invocation {self._invocation.id} ({self._invocation.get_type()}):\n{e}"
-                            )
-                            self._invoker.services.logger.error(error)
-
-                            # Send error event
-                            self._invoker.services.events.emit_invocation_error(
-                                queue_batch_id=self._queue_item.session_id,
-                                queue_item_id=self._queue_item.item_id,
-                                queue_id=self._queue_item.queue_id,
-                                graph_execution_state_id=self._queue_item.session.id,
-                                node=self._invocation.model_dump(),
-                                source_node_id=source_invocation_id,
-                                error_type=e.__class__.__name__,
-                                error=error,
-                            )
-                            pass
-
-                        # The session is complete if the all invocations are complete or there was an error
-                        if self._queue_item.session.is_complete() or cancel_event.is_set():
-                            # Send complete event
-                            self._invoker.services.events.emit_graph_execution_complete(
-                                queue_batch_id=self._queue_item.batch_id,
-                                queue_item_id=self._queue_item.item_id,
-                                queue_id=self._queue_item.queue_id,
-                                graph_execution_state_id=self._queue_item.session.id,
-                            )
-                            # If we are profiling, stop the profiler and dump the profile & stats
-                            if self._profiler:
-                                profile_path = self._profiler.stop()
-                                stats_path = profile_path.with_suffix(".json")
-                                self._invoker.services.performance_statistics.dump_stats(
-                                    graph_execution_state_id=self._queue_item.session.id, output_path=stats_path
-                                )
-                            # We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
-                            # we don't care about that - suppress the error.
-                            with suppress(GESStatsNotFoundError):
-                                self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id)
-                            self._invoker.services.performance_statistics.reset_stats()
-
-                            # Set the invocation to None to prepare for the next session
-                            self._invocation = None
-                        else:
-                            # Prepare the next invocation
-                            self._invocation = self._queue_item.session.next()
-                    else:
-                        # The queue was empty, wait for next polling interval or event to try again
-                        self._invoker.services.logger.debug("Waiting for next polling interval or event")
-                        poll_now_event.wait(self._polling_interval)
-                        continue
-                except Exception:
-                    # Non-fatal error in processor
-                    self._invoker.services.logger.error(
-                        f"Non-fatal error in session processor:\n{traceback.format_exc()}"
-                    )
-                    # Cancel the queue item
-                    if self._queue_item is not None:
-                        self._invoker.services.session_queue.cancel_queue_item(
-                            self._queue_item.item_id, error=traceback.format_exc()
-                        )
-                    # Reset the invocation to None to prepare for the next session
-                    self._invocation = None
-                    # Immediately poll for next queue item
+                if session is None:
+                    # The queue was empty, wait for next polling interval or event to try again
+                    self._invoker.services.logger.debug("Waiting for next polling interval or event")
                     poll_now_event.wait(self._polling_interval)
                     continue
+
+                self._queue_items.add(session.item_id)
+                self._session_worker_queue.put(session)
+                self._invoker.services.logger.debug(f"Executing queue item {session.item_id}")
+                cancel_event.clear()
         except Exception:
             # Fatal error in processor, log and pass - we're done here
             self._invoker.services.logger.error(f"Fatal Error in session processor:\n{traceback.format_exc()}")
@@ -288,5 +153,177 @@
         finally:
             stop_event.clear()
             poll_now_event.clear()
-            self._queue_item = None
+            self._queue_items.clear()
             self._thread_semaphore.release()
+
+    def _process_next_session(self) -> None:
+        profiler = (
+            Profiler(
+                logger=self._invoker.services.logger,
+                output_dir=self._invoker.services.configuration.profiles_path,
+                prefix=self._invoker.services.configuration.profile_prefix,
+            )
+            if self._invoker.services.configuration.profile_graphs
+            else None
+        )
+        stats_service = InvocationStatsService()
+        stats_service.start(self._invoker)
+
+        while True:
+            # Outer try block. Any unhandled exception here is a non-fatal processor error
+            try:
+                session = self._session_worker_queue.get()
+                if self._cancel_event.is_set():
+                    if session.item_id in self._sessions_to_cancel:
+                        self._invoker.services.logger.debug(f"Skipping canceled queue item {session.item_id}")
+                        continue
+
+                if profiler is not None:
+                    profiler.start(profile_id=session.session_id)
+
+                # Prepare invocations and take the first
+                with self._process_lock:
+                    invocation = session.session.next()
+
+                # Loop over invocations until the session is complete or canceled
+                while invocation is not None:
+                    if self._stop_event.is_set():
+                        break
+                    self._resume_event.wait()
+
+                    self._process_next_invocation(session, invocation, stats_service)
+
+                    # The session is complete if all invocations are complete or there was an error
+                    if session.session.is_complete():
+                        # Send complete event
+                        self._invoker.services.events.emit_graph_execution_complete(
+                            queue_batch_id=session.batch_id,
+                            queue_item_id=session.item_id,
+                            queue_id=session.queue_id,
+                            graph_execution_state_id=session.session.id,
+                        )
+                        # Log stats
+                        # We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
+                        # we don't care about that - suppress the error.
+                        with suppress(GESStatsNotFoundError):
+                            stats_service.log_stats(session.session.id)
+                        stats_service.reset_stats()
+
+                        # If we are profiling, stop the profiler and dump the profile & stats
+                        if profiler is not None:
+                            profile_path = profiler.stop()
+                            stats_path = profile_path.with_suffix(".json")
+                            stats_service.dump_stats(
+                                graph_execution_state_id=session.session.id, output_path=stats_path
+                            )
+                        self._queue_items.remove(session.item_id)
+                        invocation = None
+                    else:
+                        # Prepare the next invocation
+                        with self._process_lock:
+                            invocation = session.session.next()
+
+            except Exception:
+                # Non-fatal error in processor
+                self._invoker.services.logger.error(f"Non-fatal error in session processor:\n{traceback.format_exc()}")
+
+                # Cancel the queue item
+                if session is not None:
+                    self._invoker.services.session_queue.cancel_queue_item(
+                        session.item_id, error=traceback.format_exc()
+                    )
+            finally:
+                self._session_worker_queue.task_done()
+
+    def _process_next_invocation(
+        self,
+        session: SessionQueueItem,
+        invocation: BaseInvocation,
+        stats_service: InvocationStatsService,
+    ) -> None:
+        # get the source node id to provide to clients (the prepared node id is not as useful)
+        source_invocation_id = session.session.prepared_source_mapping[invocation.id]
+
+        self._invoker.services.logger.debug(f"Executing invocation {session.session.id}:{source_invocation_id}")
+
+        # Send starting event
+        self._invoker.services.events.emit_invocation_started(
+            queue_batch_id=session.batch_id,
+            queue_item_id=session.item_id,
+            queue_id=session.queue_id,
+            graph_execution_state_id=session.session_id,
+            node=invocation.model_dump(),
+            source_node_id=source_invocation_id,
+        )
+
+        # Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
+        try:
+            # Build invocation context (the node-facing API)
+            data = InvocationContextData(
+                invocation=invocation,
+                source_invocation_id=source_invocation_id,
+                queue_item=session,
+            )
+            context = build_invocation_context(
+                data=data,
+                services=self._invoker.services,
+                cancel_event=self._cancel_event,
+            )
+
+            # Invoke the node
+            with stats_service.collect_stats(invocation, session.session.id):
+                outputs = invocation.invoke_internal(context=context, services=self._invoker.services)
+
+            # Save outputs and history
+            session.session.complete(invocation.id, outputs)
+
+            # Send complete event
+            self._invoker.services.events.emit_invocation_complete(
+                queue_batch_id=session.batch_id,
+                queue_item_id=session.item_id,
+                queue_id=session.queue_id,
+                graph_execution_state_id=session.session.id,
+                node=invocation.model_dump(),
+                source_node_id=source_invocation_id,
+                result=outputs.model_dump(),
+            )
+
+        except KeyboardInterrupt:
+            # TODO(MM2): Create an event for this
+            pass
+
+        except CanceledException:
+            # When the user cancels the graph, we first set the cancel event. The event is checked
+            # between invocations, in this loop. Some invocations are long-running, and we need to
+            # be able to cancel them mid-execution.
+            #
+            # For example, denoising is a long-running invocation with many steps. A step callback
+            # is executed after each step. This step callback checks if the canceled event is set,
+            # then raises a CanceledException to stop execution immediately.
+            #
+            # When we get a CanceledException, we don't need to do anything - just pass and let the
+            # loop go to its next iteration, and the cancel event will be handled correctly.
+            pass
+
+        except Exception as e:
+            error = traceback.format_exc()
+
+            # Save error
+            session.session.set_node_error(invocation.id, error)
+            self._invoker.services.logger.error(
+                f"Error while invoking session {session.session_id}, invocation {invocation.id} ({invocation.get_type()}):\n{e}"
+            )
+            self._invoker.services.logger.error(error)
+
+            # Send error event
+            self._invoker.services.events.emit_invocation_error(
+                queue_batch_id=session.session_id,
+                queue_item_id=session.item_id,
+                queue_id=session.queue_id,
+                graph_execution_state_id=session.session.id,
+                node=invocation.model_dump(),
+                source_node_id=source_invocation_id,
+                error_type=e.__class__.__name__,
+                error=error,
+            )
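The processor is now split into a producer and a pool of consumers: the single `session_processor` thread dequeues queue items and pushes them onto `self._session_worker_queue`, while `max_threads` daemon workers run `_process_next_session` to drain it, each with its own `InvocationStatsService` and profiler. A self-contained sketch of that producer/consumer shape (illustrative names only, not InvokeAI code):

```python
import threading
from queue import Queue

work_queue: Queue[int] = Queue()


def worker() -> None:
    # Each worker blocks on the queue and processes items as they arrive.
    while True:
        item = work_queue.get()
        try:
            print(f"{threading.current_thread().name} processing item {item}")
        finally:
            work_queue.task_done()


for i in range(4):  # mirrors the default max_threads=4
    threading.Thread(name=f"session_worker_{i}", target=worker, daemon=True).start()

for item in range(10):  # the polling thread would do this as items are dequeued
    work_queue.put(item)

work_queue.join()  # wait until every queued item has been processed
```

Workers block cheaply on `Queue.get()` when idle; contention for GPUs is handled separately by the model cache changes below.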
diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
index 519a45c237..4478360dfe 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
@@ -27,6 +27,7 @@ from threading import BoundedSemaphore, Lock
 from typing import Dict, List, Optional, Set
 
 import torch
+from pydantic import BaseModel
 
 from invokeai.backend.model_manager import AnyModel, SubModelType
 from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot
@@ -50,6 +51,26 @@ GIG = 1073741824
 MB = 2**20
 
 
+# A GPU device can only be used by one thread at a time.
+# The refcount indicates the number of models stored in it.
+class GPUDeviceStatus(BaseModel):
+    """Track which threads are using the GPU(s) on this system."""
+
+    device: torch.device
+    thread_id: int = 0
+    refcount: int = 0
+
+    class Config:
+        """Configure the base model."""
+
+        arbitrary_types_allowed = True
+
+    def __hash__(self) -> int:
+        """Allow to be added to a set."""
+        return hash(str(self.device))
+
+
 class ModelCache(ModelCacheBase[AnyModel]):
     """Implementation of ModelCacheBase."""
 
@@ -79,7 +100,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
         """
         self._precision: torch.dtype = precision
         self._max_cache_size: float = max_cache_size
-        self._execution_devices: Set[torch.device] = execution_devices or self._get_execution_devices()
+        self._execution_devices: Set[GPUDeviceStatus] = self._get_execution_devices(execution_devices)
         self._storage_device: torch.device = storage_device
         self._lock = threading.Lock()
         self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
@@ -91,9 +112,10 @@ class ModelCache(ModelCacheBase[AnyModel]):
 
         self._lock = Lock()
         self._free_execution_device = BoundedSemaphore(len(self._execution_devices))
-        self._busy_execution_devices: Set[torch.device] = set()
 
-        self.logger.info(f"Using rendering device(s) {[self._device_name(x) for x in self._execution_devices]}")
+        self.logger.info(
+            f"Using rendering device(s): {', '.join(sorted([str(x.device) for x in self._execution_devices]))}"
+        )
 
     @property
     def logger(self) -> Logger:
@@ -108,22 +130,34 @@ class ModelCache(ModelCacheBase[AnyModel]):
     @property
     def execution_devices(self) -> Set[torch.device]:
         """Return the set of available execution devices."""
-        return self._execution_devices
+        return {x.device for x in self._execution_devices}
 
     def acquire_execution_device(self, timeout: int = 0) -> torch.device:
         """Acquire and return an execution device (e.g. "cuda" for VRAM)."""
-        with self._lock:
+        current_thread = threading.current_thread().ident
+        assert current_thread is not None
+
+        # first try to assign a device that is already executing on this thread
+        if claimed_devices := [x for x in self._execution_devices if x.thread_id == current_thread]:
+            claimed_devices[0].refcount += 1
+            return claimed_devices[0].device
+        else:
+            # this thread is not currently using any gpu. Wait for a free one
             self._free_execution_device.acquire(timeout=timeout)
-            free_devices = self.execution_devices - self._busy_execution_devices
-            chosen_device = list(free_devices)[0]
-            self._busy_execution_devices.add(chosen_device)
-            return chosen_device
+            unclaimed_devices = [x for x in self._execution_devices if x.refcount == 0]
+            unclaimed_devices[0].thread_id = current_thread
+            unclaimed_devices[0].refcount += 1
+            return unclaimed_devices[0].device
 
     def release_execution_device(self, device: torch.device) -> None:
         """Mark this execution device as unused."""
-        with self._lock:
-            self._free_execution_device.release()
-            self._busy_execution_devices.remove(device)
+        current_thread = threading.current_thread().ident
+        for x in self._execution_devices:
+            if x.thread_id == current_thread and x.device == device:
+                x.refcount -= 1
+                if x.refcount == 0:
+                    x.thread_id = 0
+                    self._free_execution_device.release()
 
     @property
     def max_cache_size(self) -> float:
@@ -174,7 +208,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
         if key in self._cached_models:
             return
         self.make_room(size)
-        cache_record = CacheRecord(key, model, size)
+        cache_record = CacheRecord(key, model=model, size=size)
         self._cached_models[key] = cache_record
         self._cache_stack.append(key)
 
@@ -361,14 +395,15 @@ class ModelCache(ModelCacheBase[AnyModel]):
                 raise torch.cuda.OutOfMemoryError
 
     @staticmethod
-    def _get_execution_devices() -> Set[torch.device]:
-        default_device = choose_torch_device()
-        if default_device != torch.device("cuda"):
-            return {default_device}
-
-        # we get here if the default device is cuda, and return each of the
-        # cuda devices.
-        return {torch.device(f"cuda:{x}") for x in range(0, torch.cuda.device_count())}
+    def _get_execution_devices(devices: Optional[Set[torch.device]] = None) -> Set[GPUDeviceStatus]:
+        if not devices:
+            default_device = choose_torch_device()
+            if default_device != torch.device("cuda"):
+                devices = {default_device}
+            else:
+                # we get here if the default device is cuda, and return each of the cuda devices.
+                devices = {torch.device(f"cuda:{x}") for x in range(0, torch.cuda.device_count())}
+        return {GPUDeviceStatus(device=x) for x in devices}
 
     @staticmethod
     def _device_name(device: torch.device) -> str:
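The device-checkout scheme above combines a `BoundedSemaphore` sized to the number of GPUs with a per-device `thread_id`/`refcount`, so a thread that already holds a device can re-acquire it without blocking while other threads wait for a free one. A minimal standalone sketch of the same pattern (hypothetical `DevicePool`/`DeviceSlot` names; locking around the slot updates is elided here, as it is in the diff):

```python
import threading
from dataclasses import dataclass


@dataclass
class DeviceSlot:
    name: str
    thread_id: int = 0
    refcount: int = 0


class DevicePool:
    def __init__(self, names: list[str]) -> None:
        self._slots = [DeviceSlot(n) for n in names]
        self._free = threading.BoundedSemaphore(len(self._slots))

    def acquire(self) -> str:
        me = threading.get_ident()
        for slot in self._slots:  # re-use a device this thread already holds
            if slot.thread_id == me:
                slot.refcount += 1
                return slot.name
        self._free.acquire()  # otherwise block until some device is free
        slot = next(s for s in self._slots if s.refcount == 0)
        slot.thread_id, slot.refcount = me, 1
        return slot.name

    def release(self, name: str) -> None:
        me = threading.get_ident()
        for slot in self._slots:
            if slot.thread_id == me and slot.name == name:
                slot.refcount -= 1
                if slot.refcount == 0:
                    slot.thread_id = 0
                    self._free.release()  # device is free for another thread


pool = DevicePool(["cuda:0", "cuda:1"])
device = pool.acquire()
pool.release(device)
```

`release()` only returns the semaphore token once the refcount drops to zero, mirroring `release_execution_device()` above.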
diff --git a/tests/conftest.py b/tests/conftest.py
index 7a7fdf32bb..97fd46de9b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -17,7 +17,6 @@ from invokeai.app.services.config.config_default import InvokeAIAppConfig
 from invokeai.app.services.images.images_default import ImageService
 from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
 from invokeai.app.services.invocation_services import InvocationServices
-from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
 from invokeai.app.services.invoker import Invoker
 from invokeai.backend.util.logging import InvokeAILogger
 from tests.fixtures.sqlite_database import create_mock_sqlite_database  # noqa: F401
@@ -48,7 +47,6 @@ def mock_services() -> InvocationServices:
         model_manager=None,  # type: ignore
         download_queue=None,  # type: ignore
         names=None,  # type: ignore
-        performance_statistics=InvocationStatsService(),
         session_processor=None,  # type: ignore
         session_queue=None,  # type: ignore
         urls=None,  # type: ignore