parallel processing working on single GPU; not yet tested on multi-GPU

Lincoln Stein 2024-04-01 00:07:47 -04:00
parent cef51ad80d
commit 9df0980c46
8 changed files with 278 additions and 209 deletions

View File

@ -23,7 +23,6 @@ from ..services.image_records.image_records_sqlite import SqliteImageRecordStora
from ..services.images.images_default import ImageService
from ..services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
from ..services.invocation_services import InvocationServices
from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
from ..services.invoker import Invoker
from ..services.model_images.model_images_default import ModelImageFileStorageDisk
from ..services.model_manager.model_manager_default import ModelManagerService
@ -102,7 +101,6 @@ class ApiDependencies:
events=events,
)
names = SimpleNameService()
performance_statistics = InvocationStatsService()
session_processor = DefaultSessionProcessor()
session_queue = SqliteSessionQueue(db=db)
urls = LocalUrlService()
@ -125,7 +123,6 @@ class ApiDependencies:
model_manager=model_manager,
download_queue=download_queue_service,
names=names,
performance_statistics=performance_statistics,
session_processor=session_processor,
session_queue=session_queue,
urls=urls,
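
The removal above takes the shared InvocationStatsService out of the service container; with parallel execution, each worker thread constructs and starts its own stats service instead (see the session processor changes later in this commit). A minimal sketch of that per-worker pattern, with the import path taken from the diff and the helper name purely illustrative:

from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService

def make_worker_stats(invoker) -> InvocationStatsService:
    # One stats service per worker thread, started against the shared invoker.
    stats_service = InvocationStatsService()
    stats_service.start(invoker)
    return stats_service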

View File

@ -749,7 +749,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
mask = mask.to(device=unet.device, dtype=unet.dtype)
if masked_latents is not None:
masked_latents = masked_latents.to(device=unet.device, dtype=unet.dtype)
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,

View File

@ -24,7 +24,7 @@ DB_FILE = Path("invokeai.db")
LEGACY_INIT_FILE = Path("invokeai.init")
DEFAULT_RAM_CACHE = 10.0
DEFAULT_CONVERT_CACHE = 20.0
DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"]
DEVICE = Literal["auto", "cpu", "cuda:0", "cuda:1", "cuda:2", "cuda:3", "cuda:4", "cuda:5", "mps"]
PRECISION = Literal["auto", "float16", "bfloat16", "float32", "autocast"]
ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"]
ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
@ -169,6 +169,7 @@ class InvokeAIAppConfig(BaseSettings):
# DEVICE
device: DEVICE = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.")
devices: Optional[list[DEVICE]] = Field(default=None, description="List of execution devices; will override default device selected.")
precision: PRECISION = Field(default="auto", description="Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.")
# GENERATION
@ -178,6 +179,7 @@ class InvokeAIAppConfig(BaseSettings):
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).")
pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.")
max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.")
max_threads: int = Field(default=4, description="Maximum number of session queue execution threads.")
# NODES
allow_nodes: Optional[list[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.")
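
A minimal sketch of how the two settings introduced above might be exercised, assuming the settings object can be constructed directly with keyword overrides (InvokeAIAppConfig is a pydantic BaseSettings, and the field names come from the diff; the values are illustrative):

from invokeai.app.services.config.config_default import InvokeAIAppConfig

config = InvokeAIAppConfig(
    device="auto",                 # automatic selection unless overridden
    devices=["cuda:0", "cuda:1"],  # explicit device list overrides `device`
    max_threads=4,                 # number of session queue worker threads
)
print(config.devices, config.max_threads)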

View File

@ -24,7 +24,6 @@ if TYPE_CHECKING:
from .image_records.image_records_base import ImageRecordStorageBase
from .images.images_base import ImageServiceABC
from .invocation_cache.invocation_cache_base import InvocationCacheBase
from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase
from .model_images.model_images_base import ModelImageFileStorageBase
from .model_manager.model_manager_base import ModelManagerServiceBase
from .names.names_base import NameServiceBase
@ -53,7 +52,6 @@ class InvocationServices:
model_images: "ModelImageFileStorageBase",
model_manager: "ModelManagerServiceBase",
download_queue: "DownloadQueueServiceBase",
performance_statistics: "InvocationStatsServiceBase",
session_queue: "SessionQueueBase",
session_processor: "SessionProcessorBase",
invocation_cache: "InvocationCacheBase",
@ -77,7 +75,6 @@ class InvocationServices:
self.model_images = model_images
self.model_manager = model_manager
self.download_queue = download_queue
self.performance_statistics = performance_statistics
self.session_queue = session_queue
self.session_processor = session_processor
self.invocation_cache = invocation_cache

View File

@ -1,8 +1,6 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
"""Implementation of ModelManagerServiceBase."""
from typing import Optional, Set
import torch
from typing_extensions import Self
@ -68,7 +66,6 @@ class ModelManagerService(ModelManagerServiceBase):
model_record_service: ModelRecordServiceBase,
download_queue: DownloadQueueServiceBase,
events: EventServiceBase,
execution_devices: Optional[Set[torch.device]] = None,
) -> Self:
"""
Construct the model manager service instance.
@ -78,6 +75,13 @@ class ModelManagerService(ModelManagerServiceBase):
logger = InvokeAILogger.get_logger(cls.__name__)
logger.setLevel(app_config.log_level.upper())
execution_devices = (
None
if app_config.devices is None or "auto" in app_config.devices
else {torch.device(x) for x in app_config.devices}
)
ram_cache = ModelCache(
max_cache_size=app_config.ram,
logger=logger,
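
A standalone sketch of the device-selection rule added above: a devices value of None, or one containing "auto", defers to automatic selection, while anything else is parsed into a set of torch.device objects. The helper name is hypothetical; only the rule itself comes from the diff:

from typing import Optional, Set

import torch

def resolve_execution_devices(devices: Optional[list[str]]) -> Optional[Set[torch.device]]:
    # None means "let the model cache pick its own devices".
    if devices is None or "auto" in devices:
        return None
    return {torch.device(d) for d in devices}

assert resolve_execution_devices(None) is None
assert resolve_execution_devices(["auto", "cuda:0"]) is None
assert resolve_execution_devices(["cuda:0", "cuda:1"]) == {torch.device("cuda:0"), torch.device("cuda:1")}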

View File

@ -1,8 +1,9 @@
import traceback
from contextlib import suppress
from threading import BoundedSemaphore, Thread
from queue import Queue
from threading import BoundedSemaphore, Lock, Thread
from threading import Event as ThreadEvent
from typing import Optional
from typing import Optional, Set
from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event as FastAPIEvent
@ -10,6 +11,7 @@ from fastapi_events.typing import Event as FastAPIEvent
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
from invokeai.app.services.session_processor.session_processor_common import CanceledException
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context
@ -23,7 +25,8 @@ from .session_processor_common import SessionProcessorStatus
class DefaultSessionProcessor(SessionProcessorBase):
def start(self, invoker: Invoker, polling_interval: int = 1) -> None:
self._invoker: Invoker = invoker
self._queue_item: Optional[SessionQueueItem] = None
self._queue_items: Set[int] = set()
self._sessions_to_cancel: Set[int] = set()
self._invocation: Optional[BaseInvocation] = None
self._resume_event = ThreadEvent()
@ -33,10 +36,14 @@ class DefaultSessionProcessor(SessionProcessorBase):
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_queue_event)
self._thread_limit = self._invoker.services.model_manager.load.gpu_count
self._thread_limit = 1  # temporarily pinned to a single device; multi-GPU execution is not yet tested
self._thread_semaphore = BoundedSemaphore(self._thread_limit)
self._polling_interval = polling_interval
self._worker_thread_count = self._invoker.services.configuration.max_threads
self._session_worker_queue: Queue[SessionQueueItem] = Queue()
self._process_lock = Lock()
# If profiling is enabled, create a profiler. The same profiler will be used for all sessions. Internally,
# the profiler will create a new profile for each session.
self._profiler = (
@ -49,6 +56,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
else None
)
# main session processor loop - single thread
self._thread = Thread(
name="session_processor",
target=self._process,
@ -61,6 +69,16 @@ class DefaultSessionProcessor(SessionProcessorBase):
)
self._thread.start()
# Session processor workers - multithreaded
self._invoker.services.logger.debug(f"Starting {self._worker_thread_count} session processing threads.")
for _i in range(0, self._worker_thread_count):
worker = Thread(
name="session_worker",
target=self._process_next_session,
daemon=True,
)
worker.start()
def stop(self, *args, **kwargs) -> None:
self._stop_event.set()
@ -70,18 +88,12 @@ class DefaultSessionProcessor(SessionProcessorBase):
async def _on_queue_event(self, event: FastAPIEvent) -> None:
event_name = event[1]["event"]
if (
event_name == "session_canceled"
and self._queue_item
and self._queue_item.item_id == event[1]["data"]["queue_item_id"]
):
if event_name == "session_canceled" and event[1]["data"]["queue_item_id"] in self._queue_items:
self._sessions_to_cancel.add(event[1]["data"]["queue_item_id"])
self._cancel_event.set()
self._poll_now()
elif (
event_name == "queue_cleared"
and self._queue_item
and self._queue_item.queue_id == event[1]["data"]["queue_id"]
):
elif event_name == "queue_cleared" and event[1]["data"]["queue_id"] in self._queue_items:
self._sessions_to_cancel.add(event[1]["data"]["queue_item_id"])
self._cancel_event.set()
self._poll_now()
elif event_name == "batch_enqueued":
@ -100,7 +112,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
def get_status(self) -> SessionProcessorStatus:
return SessionProcessorStatus(
is_started=self._resume_event.is_set(),
is_processing=self._queue_item is not None,
is_processing=len(self._queue_items) > 0,
)
def _process(
@ -109,7 +121,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
poll_now_event: ThreadEvent,
resume_event: ThreadEvent,
cancel_event: ThreadEvent,
):
) -> None:
# Outermost processor try block; any unhandled exception is a fatal processor error
try:
self._thread_semaphore.acquire()
@ -119,168 +131,21 @@ class DefaultSessionProcessor(SessionProcessorBase):
while not stop_event.is_set():
poll_now_event.clear()
# Middle processor try block; any unhandled exception is a non-fatal processor error
try:
# If we are paused, wait for resume event
resume_event.wait()
resume_event.wait()
# Get the next session to process
self._queue_item = self._invoker.services.session_queue.dequeue()
# Get the next session to process
session = self._invoker.services.session_queue.dequeue()
if self._queue_item is None:
# The queue was empty, wait for next polling interval or event to try again
self._invoker.services.logger.debug("Waiting for next polling interval or event")
poll_now_event.wait(self._polling_interval)
continue
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
cancel_event.clear()
# If profiling is enabled, start the profiler
if self._profiler is not None:
self._profiler.start(profile_id=self._queue_item.session_id)
# Prepare invocations and take the first
self._invocation = self._queue_item.session.next()
# Loop over invocations until the session is complete or canceled
while self._invocation is not None and not cancel_event.is_set():
# get the source node id to provide to clients (the prepared node id is not as useful)
source_invocation_id = self._queue_item.session.prepared_source_mapping[self._invocation.id]
# Send starting event
self._invoker.services.events.emit_invocation_started(
queue_batch_id=self._queue_item.batch_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session_id,
node=self._invocation.model_dump(),
source_node_id=source_invocation_id,
)
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
try:
with self._invoker.services.performance_statistics.collect_stats(
self._invocation, self._queue_item.session.id
):
# Build invocation context (the node-facing API)
data = InvocationContextData(
invocation=self._invocation,
source_invocation_id=source_invocation_id,
queue_item=self._queue_item,
)
context = build_invocation_context(
data=data,
services=self._invoker.services,
cancel_event=self._cancel_event,
)
# Invoke the node
outputs = self._invocation.invoke_internal(
context=context, services=self._invoker.services
)
# Save outputs and history
self._queue_item.session.complete(self._invocation.id, outputs)
# Send complete event
self._invoker.services.events.emit_invocation_complete(
queue_batch_id=self._queue_item.batch_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session.id,
node=self._invocation.model_dump(),
source_node_id=source_invocation_id,
result=outputs.model_dump(),
)
except KeyboardInterrupt:
# TODO(MM2): Create an event for this
pass
except CanceledException:
# When the user cancels the graph, we first set the cancel event. The event is checked
# between invocations, in this loop. Some invocations are long-running, and we need to
# be able to cancel them mid-execution.
#
# For example, denoising is a long-running invocation with many steps. A step callback
# is executed after each step. This step callback checks if the canceled event is set,
# then raises a CanceledException to stop execution immediately.
#
# When we get a CanceledException, we don't need to do anything - just pass and let the
# loop go to its next iteration, and the cancel event will be handled correctly.
pass
except Exception as e:
error = traceback.format_exc()
# Save error
self._queue_item.session.set_node_error(self._invocation.id, error)
self._invoker.services.logger.error(
f"Error while invoking session {self._queue_item.session_id}, invocation {self._invocation.id} ({self._invocation.get_type()}):\n{e}"
)
self._invoker.services.logger.error(error)
# Send error event
self._invoker.services.events.emit_invocation_error(
queue_batch_id=self._queue_item.session_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session.id,
node=self._invocation.model_dump(),
source_node_id=source_invocation_id,
error_type=e.__class__.__name__,
error=error,
)
pass
# The session is complete if the all invocations are complete or there was an error
if self._queue_item.session.is_complete() or cancel_event.is_set():
# Send complete event
self._invoker.services.events.emit_graph_execution_complete(
queue_batch_id=self._queue_item.batch_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session.id,
)
# If we are profiling, stop the profiler and dump the profile & stats
if self._profiler:
profile_path = self._profiler.stop()
stats_path = profile_path.with_suffix(".json")
self._invoker.services.performance_statistics.dump_stats(
graph_execution_state_id=self._queue_item.session.id, output_path=stats_path
)
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
# we don't care about that - suppress the error.
with suppress(GESStatsNotFoundError):
self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id)
self._invoker.services.performance_statistics.reset_stats()
# Set the invocation to None to prepare for the next session
self._invocation = None
else:
# Prepare the next invocation
self._invocation = self._queue_item.session.next()
else:
# The queue was empty, wait for next polling interval or event to try again
self._invoker.services.logger.debug("Waiting for next polling interval or event")
poll_now_event.wait(self._polling_interval)
continue
except Exception:
# Non-fatal error in processor
self._invoker.services.logger.error(
f"Non-fatal error in session processor:\n{traceback.format_exc()}"
)
# Cancel the queue item
if self._queue_item is not None:
self._invoker.services.session_queue.cancel_queue_item(
self._queue_item.item_id, error=traceback.format_exc()
)
# Reset the invocation to None to prepare for the next session
self._invocation = None
# Immediately poll for next queue item
if session is None:
# The queue was empty, wait for next polling interval or event to try again
self._invoker.services.logger.debug("Waiting for next polling interval or event")
poll_now_event.wait(self._polling_interval)
continue
self._queue_items.add(session.item_id)
self._session_worker_queue.put(session)
self._invoker.services.logger.debug(f"Executing queue item {session.item_id}")
cancel_event.clear()
except Exception:
# Fatal error in processor, log and pass - we're done here
self._invoker.services.logger.error(f"Fatal Error in session processor:\n{traceback.format_exc()}")
@ -288,5 +153,177 @@ class DefaultSessionProcessor(SessionProcessorBase):
finally:
stop_event.clear()
poll_now_event.clear()
self._queue_item = None
self._queue_items.clear()
self._thread_semaphore.release()
def _process_next_session(self) -> None:
profiler = (
Profiler(
logger=self._invoker.services.logger,
output_dir=self._invoker.services.configuration.profiles_path,
prefix=self._invoker.services.configuration.profile_prefix,
)
if self._invoker.services.configuration.profile_graphs
else None
)
stats_service = InvocationStatsService()
stats_service.start(self._invoker)
while True:
# Outer try block. Any error here is a fatal processor error
try:
session = self._session_worker_queue.get()
if self._cancel_event.is_set():
if session.item_id in self._sessions_to_cancel:
print("DEBUG: CANCEL")
continue
if profiler is not None:
profiler.start(profile_id=session.session_id)
# Prepare invocations and take the first
with self._process_lock:
invocation = session.session.next()
# Loop over invocations until the session is complete or canceled
while invocation is not None:
if self._stop_event.is_set():
break
self._resume_event.wait()
self._process_next_invocation(session, invocation, stats_service)
# The session is complete if all invocations are complete or there was an error
if session.session.is_complete():
# Send complete event
self._invoker.services.events.emit_graph_execution_complete(
queue_batch_id=session.batch_id,
queue_item_id=session.item_id,
queue_id=session.queue_id,
graph_execution_state_id=session.session.id,
)
# Log stats
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
# we don't care about that - suppress the error.
with suppress(GESStatsNotFoundError):
stats_service.log_stats(session.session.id)
stats_service.reset_stats()
# If we are profiling, stop the profiler and dump the profile & stats
if profiler is not None:
profile_path = profiler.stop()
stats_path = profile_path.with_suffix(".json")
stats_service.dump_stats(
graph_execution_state_id=session.session.id, output_path=stats_path
)
self._queue_items.remove(session.item_id)
invocation = None
else:
# Prepare the next invocation
with self._process_lock:
invocation = session.session.next()
except Exception:
# Non-fatal error in processor
self._invoker.services.logger.error(f"Non-fatal error in session processor:\n{traceback.format_exc()}")
# Cancel the queue item
if session is not None:
self._invoker.services.session_queue.cancel_queue_item(
session.item_id, error=traceback.format_exc()
)
finally:
self._session_worker_queue.task_done()
def _process_next_invocation(
self,
session: SessionQueueItem,
invocation: BaseInvocation,
stats_service: InvocationStatsService,
) -> None:
# get the source node id to provide to clients (the prepared node id is not as useful)
source_invocation_id = session.session.prepared_source_mapping[invocation.id]
self._invoker.services.logger.debug(f"Executing invocation {session.session.id}:{source_invocation_id}")
# Send starting event
self._invoker.services.events.emit_invocation_started(
queue_batch_id=session.batch_id,
queue_item_id=session.item_id,
queue_id=session.queue_id,
graph_execution_state_id=session.session_id,
node=invocation.model_dump(),
source_node_id=source_invocation_id,
)
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
try:
# Build invocation context (the node-facing API)
data = InvocationContextData(
invocation=invocation,
source_invocation_id=source_invocation_id,
queue_item=session,
)
context = build_invocation_context(
data=data,
services=self._invoker.services,
cancel_event=self._cancel_event,
)
# Invoke the node
# title = invocation.UIConfig.title
with stats_service.collect_stats(invocation, session.session.id):
outputs = invocation.invoke_internal(context=context, services=self._invoker.services)
# Save outputs and history
session.session.complete(invocation.id, outputs)
# Send complete event
self._invoker.services.events.emit_invocation_complete(
queue_batch_id=session.batch_id,
queue_item_id=session.item_id,
queue_id=session.queue_id,
graph_execution_state_id=session.session.id,
node=invocation.model_dump(),
source_node_id=source_invocation_id,
result=outputs.model_dump(),
)
except KeyboardInterrupt:
# TODO(MM2): Create an event for this
pass
except CanceledException:
# When the user cancels the graph, we first set the cancel event. The event is checked
# between invocations, in this loop. Some invocations are long-running, and we need to
# be able to cancel them mid-execution.
#
# For example, denoising is a long-running invocation with many steps. A step callback
# is executed after each step. This step callback checks if the canceled event is set,
# then raises a CanceledException to stop execution immediately.
#
# When we get a CanceledException, we don't need to do anything - just pass and let the
# loop go to its next iteration, and the cancel event will be handled correctly.
pass
except Exception as e:
error = traceback.format_exc()
# Save error
session.session.set_node_error(invocation.id, error)
self._invoker.services.logger.error(
f"Error while invoking session {session.session_id}, invocation {invocation.id} ({invocation.get_type()}):\n{e}"
)
self._invoker.services.logger.error(error)
# Send error event
self._invoker.services.events.emit_invocation_error(
queue_batch_id=session.session_id,
queue_item_id=session.item_id,
queue_id=session.queue_id,
graph_execution_state_id=session.session.id,
node=invocation.model_dump(),
source_node_id=source_invocation_id,
error_type=e.__class__.__name__,
error=error,
)
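
The new processor splits work between a single poller thread (_process), which dequeues sessions and pushes them onto an in-memory Queue, and a pool of daemon worker threads (_process_next_session), which drain that queue. Below is a simplified, self-contained sketch of the same producer/consumer shape; the Job type, the loop bodies, and the counts are placeholders rather than InvokeAI APIs:

import time
from dataclasses import dataclass
from queue import Empty, Queue
from threading import Event, Thread

@dataclass
class Job:
    item_id: int

job_queue: Queue[Job] = Queue()
stop_event = Event()

def poller() -> None:
    # Stands in for the _process thread: fetch work and hand it to the workers.
    for item_id in range(8):
        job_queue.put(Job(item_id))
        time.sleep(0.2)  # stands in for the polling interval
    stop_event.set()

def worker(worker_id: int) -> None:
    # Stands in for _process_next_session: drain the queue until told to stop.
    while not (stop_event.is_set() and job_queue.empty()):
        try:
            job = job_queue.get(timeout=0.1)
        except Empty:
            continue
        try:
            print(f"worker {worker_id} processing item {job.item_id}")
        finally:
            job_queue.task_done()  # mirrors the finally: task_done() above

for i in range(4):  # max_threads workers
    Thread(target=worker, args=(i,), daemon=True).start()
poller()
job_queue.join()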

View File

@ -27,6 +27,7 @@ from threading import BoundedSemaphore, Lock
from typing import Dict, List, Optional, Set
import torch
from pydantic import BaseModel
from invokeai.backend.model_manager import AnyModel, SubModelType
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot
@ -50,6 +51,26 @@ GIG = 1073741824
MB = 2**20
# A GPU device can only be used by one thread at a time.
# The refcount tracks how many times the owning thread has acquired
# the device without yet releasing it.
class GPUDeviceStatus(BaseModel):
"""Tracks which threads are using the GPU(s) on this system."""
device: torch.device
thread_id: int = 0
refcount: int = 0
class Config:
"""Configure the base model."""
arbitrary_types_allowed = True
def __hash__(self) -> int:
"""Allow instances to be added to a set."""
return hash(str(self.device))
class ModelCache(ModelCacheBase[AnyModel]):
"""Implementation of ModelCacheBase."""
@ -79,7 +100,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
"""
self._precision: torch.dtype = precision
self._max_cache_size: float = max_cache_size
self._execution_devices: Set[torch.device] = execution_devices or self._get_execution_devices()
self._execution_devices: Set[GPUDeviceStatus] = self._get_execution_devices(execution_devices)
self._storage_device: torch.device = storage_device
self._lock = threading.Lock()
self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
@ -91,9 +112,10 @@ class ModelCache(ModelCacheBase[AnyModel]):
self._lock = Lock()
self._free_execution_device = BoundedSemaphore(len(self._execution_devices))
self._busy_execution_devices: Set[torch.device] = set()
self.logger.info(f"Using rendering device(s) {[self._device_name(x) for x in self._execution_devices]}")
self.logger.info(
f"Using rendering device(s): {', '.join(sorted([str(x.device) for x in self._execution_devices]))}"
)
@property
def logger(self) -> Logger:
@ -108,22 +130,34 @@ class ModelCache(ModelCacheBase[AnyModel]):
@property
def execution_devices(self) -> Set[torch.device]:
"""Return the set of available execution devices."""
return self._execution_devices
return {x.device for x in self._execution_devices}
def acquire_execution_device(self, timeout: int = 0) -> torch.device:
"""Acquire and return an execution device (e.g. "cuda" for VRAM)."""
with self._lock:
current_thread = threading.current_thread().ident
assert current_thread is not None
# first try to assign a device that is already executing on this thread
if claimed_devices := [x for x in self._execution_devices if x.thread_id == current_thread]:
claimed_devices[0].refcount += 1
return claimed_devices[0].device
else:
# this thread is not currently using any gpu. Wait for a free one
self._free_execution_device.acquire(timeout=timeout)
free_devices = self.execution_devices - self._busy_execution_devices
chosen_device = list(free_devices)[0]
self._busy_execution_devices.add(chosen_device)
return chosen_device
unclaimed_devices = [x for x in self._execution_devices if x.refcount == 0]
unclaimed_devices[0].thread_id = current_thread
unclaimed_devices[0].refcount += 1
return unclaimed_devices[0].device
def release_execution_device(self, device: torch.device) -> None:
"""Mark this execution device as unused."""
with self._lock:
self._free_execution_device.release()
self._busy_execution_devices.remove(device)
current_thread = threading.current_thread().ident
for x in self._execution_devices:
if x.thread_id == current_thread and x.device == device:
x.refcount -= 1
if x.refcount == 0:
x.thread_id = 0
self._free_execution_device.release()
@property
def max_cache_size(self) -> float:
@ -174,7 +208,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
if key in self._cached_models:
return
self.make_room(size)
cache_record = CacheRecord(key, model, size)
cache_record = CacheRecord(key, model=model, size=size)
self._cached_models[key] = cache_record
self._cache_stack.append(key)
@ -361,14 +395,15 @@ class ModelCache(ModelCacheBase[AnyModel]):
raise torch.cuda.OutOfMemoryError
@staticmethod
def _get_execution_devices() -> Set[torch.device]:
default_device = choose_torch_device()
if default_device != torch.device("cuda"):
return {default_device}
# we get here if the default device is cuda, and return each of the
# cuda devices.
return {torch.device(f"cuda:{x}") for x in range(0, torch.cuda.device_count())}
def _get_execution_devices(devices: Optional[Set[torch.device]] = None) -> Set[GPUDeviceStatus]:
if not devices:
default_device = choose_torch_device()
if default_device != torch.device("cuda"):
devices = {default_device}
else:
# we get here if the default device is cuda, and return each of the cuda devices.
devices = {torch.device(f"cuda:{x}") for x in range(0, torch.cuda.device_count())}
return {GPUDeviceStatus(device=x) for x in devices}
@staticmethod
def _device_name(device: torch.device) -> str:
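
An illustrative usage pattern for the refcounted device claim added above. The acquire_execution_device and release_execution_device names and the timeout parameter come from the diff; the context manager wrapper is hypothetical, and `cache` stands for an already constructed ModelCache:

from contextlib import contextmanager
from typing import Iterator

import torch

@contextmanager
def execution_device(cache) -> Iterator[torch.device]:
    # The calling thread owns the device inside the block; nested acquisitions
    # on the same thread only bump the refcount.
    device = cache.acquire_execution_device(timeout=30)
    try:
        yield device
    finally:
        cache.release_execution_device(device)

# with execution_device(model_cache) as device:
#     latents = latents.to(device)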

View File

@ -17,7 +17,6 @@ from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.images.images_default import ImageService
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
from invokeai.app.services.invoker import Invoker
from invokeai.backend.util.logging import InvokeAILogger
from tests.fixtures.sqlite_database import create_mock_sqlite_database # noqa: F401
@ -48,7 +47,6 @@ def mock_services() -> InvocationServices:
model_manager=None, # type: ignore
download_queue=None, # type: ignore
names=None, # type: ignore
performance_statistics=InvocationStatsService(),
session_processor=None, # type: ignore
session_queue=None, # type: ignore
urls=None, # type: ignore