fixup unit tests and remove debugging statements

Lincoln Stein 2024-06-02 18:19:29 -04:00
parent e26360f85b
commit 589a7959c0
11 changed files with 61 additions and 186 deletions

View File

@@ -4,7 +4,6 @@ from logging import Logger
 import torch
-import invokeai.backend.util.devices  # horrible hack
 from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk
 from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache
 from invokeai.app.services.shared.sqlite.sqlite_util import init_db

View File

@@ -99,6 +99,7 @@ class CompelInvocation(BaseInvocation):
             textual_inversion_manager=ti_manager,
             dtype_for_device_getter=TorchDevice.choose_torch_dtype,
             truncate_long_prompts=False,
+            device=TorchDevice.choose_torch_device(),
         )
         conjunction = Compel.parse_prompt_string(self.prompt)
@@ -113,6 +114,7 @@ class CompelInvocation(BaseInvocation):
         conditioning_data = ConditioningFieldData(conditionings=[BasicConditioningInfo(embeds=c)])
         conditioning_name = context.conditioning.save(conditioning_data)
         return ConditioningOutput(
             conditioning=ConditioningField(
                 conditioning_name=conditioning_name,

View File

@@ -74,9 +74,9 @@ class InvocationStatsService(InvocationStatsServiceBase):
             )
             self._stats[graph_execution_state_id].add_node_execution_stats(node_stats)
-    def reset_stats(self):
-        self._stats = {}
-        self._cache_stats = {}
+    def reset_stats(self, graph_execution_state_id: str):
+        self._stats.pop(graph_execution_state_id)
+        self._cache_stats.pop(graph_execution_state_id)
     def get_stats(self, graph_execution_state_id: str) -> InvocationStatsSummary:
         graph_stats_summary = self._get_graph_summary(graph_execution_state_id)
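
Note: the hunk above changes `reset_stats()` from wiping every session's statistics to dropping a single session by ID. A minimal sketch of that per-session behaviour, using stand-in names rather than the project's actual service wiring:

```python
# Illustration only: a toy stats store with the same pop()-per-session
# semantics as the reset_stats() change above. All names are hypothetical.
class SessionStats:
    def __init__(self) -> None:
        self._stats: dict[str, list[float]] = {}

    def add_node_time(self, session_id: str, seconds: float) -> None:
        self._stats.setdefault(session_id, []).append(seconds)

    def reset_stats(self, session_id: str) -> None:
        # Drop only the finished session; other in-flight sessions keep their data.
        # Like dict.pop() in the diff, this raises if the session is unknown.
        self._stats.pop(session_id)


stats = SessionStats()
stats.add_node_time("session-a", 0.12)
stats.add_node_time("session-b", 0.50)
stats.reset_stats("session-a")  # "session-b" statistics survive
```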

View File

@@ -76,8 +76,6 @@ class ModelManagerService(ModelManagerServiceBase):
         ram_cache = ModelCache(
             max_cache_size=app_config.ram,
-            max_vram_cache_size=app_config.vram,
-            lazy_offloading=app_config.lazy_offload,
             logger=logger,
         )
         convert_cache = ModelConvertCache(cache_path=app_config.convert_cache_path, max_size=app_config.convert_cache)

View File

@@ -1,7 +1,7 @@
 import traceback
 from contextlib import suppress
 from queue import Queue
-from threading import BoundedSemaphore, Thread, Lock
+from threading import BoundedSemaphore, Lock, Thread
 from threading import Event as ThreadEvent
 from typing import Optional, Set
@@ -61,7 +61,9 @@ class DefaultSessionRunner(SessionRunnerBase):
         self._on_after_run_session_callbacks = on_after_run_session_callbacks or []
         self._process_lock = Lock()
-    def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None) -> None:
+    def start(
+        self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None
+    ) -> None:
         self._services = services
         self._cancel_event = cancel_event
         self._profiler = profiler
@@ -214,7 +216,7 @@ class DefaultSessionRunner(SessionRunnerBase):
         # we don't care about that - suppress the error.
         with suppress(GESStatsNotFoundError):
             self._services.performance_statistics.log_stats(queue_item.session.id)
-            self._services.performance_statistics.reset_stats()
+            self._services.performance_statistics.reset_stats(queue_item.session.id)
         for callback in self._on_after_run_session_callbacks:
             callback(queue_item=queue_item)
@@ -384,7 +386,6 @@ class DefaultSessionProcessor(SessionProcessorBase):
             )
             worker.start()
     def stop(self, *args, **kwargs) -> None:
         self._stop_event.set()
@@ -465,7 +466,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
                     # Run the graph
                     # self.session_runner.run(queue_item=self._queue_item)
-                except Exception as e:
+                except Exception:
                     # Wait for next polling interval or event to try again
                     poll_now_event.wait(self._polling_interval)
                     continue
@@ -494,7 +495,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
                 with self._invoker.services.model_manager.load.ram_cache.reserve_execution_device():
                     # Run the session on the reserved GPU
                    self.session_runner.run(queue_item=queue_item)
-            except Exception as e:
+            except Exception:
                continue
            finally:
                self._active_queue_items.remove(queue_item)
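
Note: the last two hunks tighten the exception handlers and show each session being run inside the cache's `reserve_execution_device()` context manager. A rough sketch of that reserve-then-run pattern; apart from the two names visible in the diff (`reserve_execution_device`, `session_runner.run`), everything here is assumed:

```python
# Sketch under the assumption that the RAM cache exposes
# reserve_execution_device() as a context manager, as used above.
def process_queue_item(queue_item, session_runner, ram_cache, logger) -> None:
    try:
        # Block until an execution device (GPU) is free, hold it for the
        # whole session, and release it when the block exits.
        with ram_cache.reserve_execution_device():
            session_runner.run(queue_item=queue_item)
    except Exception:
        # Mirror the diff: log and move on so the polling loop keeps going.
        logger.exception("session failed; continuing with the next queue item")
```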

View File

@@ -239,6 +239,7 @@ class SessionQueueItemWithoutGraph(BaseModel):
     def __hash__(self) -> int:
         return self.item_id
 class SessionQueueItemDTO(SessionQueueItemWithoutGraph):
     pass

View File

@@ -325,7 +325,6 @@ class ConditioningInterface(InvocationContextInterface):
         Returns:
             The loaded conditioning data.
         """
         return self._services.conditioning.load(name)

View File

@ -43,26 +43,9 @@ T = TypeVar("T")
@dataclass @dataclass
class CacheRecord(Generic[T]): class CacheRecord(Generic[T]):
""" """Elements of the cache."""
Elements of the cache:
key: Unique key for each model, same as used in the models database.
model: Model in memory.
state_dict: A read-only copy of the model's state dict in RAM. It will be
used as a template for creating a copy in the VRAM.
size: Size of the model
loaded: True if the model's state dict is currently in VRAM
Before a model is executed, the state_dict template is copied into VRAM,
and then injected into the model. When the model is finished, the VRAM
copy of the state dict is deleted, and the RAM version is reinjected
into the model.
"""
key: str key: str
model: T
device: torch.device
state_dict: Optional[Dict[str, torch.Tensor]]
size: int size: int
model: T model: T
loaded: bool = False loaded: bool = False
@@ -130,28 +113,12 @@ class ModelCacheBase(ABC, Generic[T]):
         """
         pass
-    @property
-    @abstractmethod
-    def lazy_offloading(self) -> bool:
-        """Return true if the cache is configured to lazily offload models in VRAM."""
-        pass
     @property
     @abstractmethod
     def max_cache_size(self) -> float:
         """Return true if the cache is configured to lazily offload models in VRAM."""
         pass
-    @abstractmethod
-    def offload_unlocked_models(self, size_required: int) -> None:
-        """Offload from VRAM any models not actively in use."""
-        pass
-    @abstractmethod
-    def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
-        """Move model into the indicated device."""
-        pass
     @property
     @abstractmethod
     def stats(self) -> Optional[CacheStats]:

View File

@@ -19,10 +19,8 @@ context. Use like this:
 """
 import gc
-import math
 import sys
 import threading
-import time
 from contextlib import contextmanager, suppress
 from logging import Logger
 from threading import BoundedSemaphore
@@ -31,7 +29,7 @@ from typing import Dict, Generator, List, Optional, Set
 import torch
 from invokeai.backend.model_manager import AnyModel, SubModelType
-from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
+from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot
 from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.logging import InvokeAILogger
@@ -42,11 +40,6 @@ from .model_locker import ModelLocker
 # Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
 DEFAULT_MAX_CACHE_SIZE = 6.0
-# amount of GPU memory to hold in reserve for use by generations (GB)
-# Empirically this value seems to improve performance without starving other
-# processes.
-DEFAULT_MAX_VRAM_CACHE_SIZE = 0.25
 # actual size of a gig
 GIG = 1073741824
@@ -60,12 +53,10 @@ class ModelCache(ModelCacheBase[AnyModel]):
     def __init__(
         self,
         max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
-        max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
         storage_device: torch.device = torch.device("cpu"),
         execution_devices: Optional[Set[torch.device]] = None,
         precision: torch.dtype = torch.float16,
         sequential_offload: bool = False,
-        lazy_offloading: bool = True,
         sha_chunksize: int = 16777216,
         log_memory_usage: bool = False,
         logger: Optional[Logger] = None,
@@ -76,18 +67,14 @@ class ModelCache(ModelCacheBase[AnyModel]):
         :param max_cache_size: Maximum size of the RAM cache [6.0 GB]
         :param storage_device: Torch device to save inactive model in [torch.device('cpu')]
         :param precision: Precision for loaded models [torch.float16]
-        :param lazy_offloading: Keep model in VRAM until another model needs to be loaded
         :param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
         :param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
             operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
             snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
             behaviour.
         """
-        # allow lazy offloading only when vram cache enabled
-        self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0
         self._precision: torch.dtype = precision
         self._max_cache_size: float = max_cache_size
-        self._max_vram_cache_size: float = max_vram_cache_size
         self._storage_device: torch.device = storage_device
         self._ram_lock = threading.Lock()
         self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
@@ -111,11 +98,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
         """Return the logger used by the cache."""
         return self._logger
-    @property
-    def lazy_offloading(self) -> bool:
-        """Return true if the cache is configured to lazily offload models in VRAM."""
-        return self._lazy_offloading
     @property
     def storage_device(self) -> torch.device:
         """Return the storage device (e.g. "CPU" for RAM)."""
@@ -233,8 +215,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
         if key in self._cached_models:
             return
         self.make_room(size)
-        state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None
-        cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size)
+        cache_record = CacheRecord(key, model=model, size=size)
         self._cached_models[key] = cache_record
         self._cache_stack.append(key)
@@ -296,107 +277,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
         else:
             return model_key
-    def offload_unlocked_models(self, size_required: int) -> None:
-        """Move any unused models from VRAM."""
-        reserved = self._max_vram_cache_size * GIG
-        vram_in_use = torch.cuda.memory_allocated() + size_required
-        self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB")
-        for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
-            if vram_in_use <= reserved:
-                break
-            if not cache_entry.loaded:
-                continue
-            if not cache_entry.locked:
-                self.move_model_to_device(cache_entry, self.storage_device)
-                cache_entry.loaded = False
-                vram_in_use = torch.cuda.memory_allocated() + size_required
-                self.logger.debug(
-                    f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
-                )
-        TorchDevice.empty_cache()
-    def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
-        """Move model into the indicated device.
-        :param cache_entry: The CacheRecord for the model
-        :param target_device: The torch.device to move the model into
-        May raise a torch.cuda.OutOfMemoryError
-        """
-        # These attributes are not in the base ModelMixin class but in various derived classes.
-        # Some models don't have these attributes, in which case they run in RAM/CPU.
-        self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
-        if not (hasattr(cache_entry.model, "device") and hasattr(cache_entry.model, "to")):
-            return
-        source_device = cache_entry.device
-        # Note: We compare device types only so that 'cuda' == 'cuda:0'.
-        # This would need to be revised to support multi-GPU.
-        if torch.device(source_device).type == torch.device(target_device).type:
-            return
-        # This roundabout method for moving the model around is done to avoid
-        # the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
-        # When moving to VRAM, we copy (not move) each element of the state dict from
-        # RAM to a new state dict in VRAM, and then inject it into the model.
-        # This operation is slightly faster than running `to()` on the whole model.
-        #
-        # When the model needs to be removed from VRAM we simply delete the copy
-        # of the state dict in VRAM, and reinject the state dict that is cached
-        # in RAM into the model. So this operation is very fast.
-        start_model_to_time = time.time()
-        snapshot_before = self._capture_memory_snapshot()
-        try:
-            if cache_entry.state_dict is not None:
-                assert hasattr(cache_entry.model, "load_state_dict")
-                if target_device == self.storage_device:
-                    cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True)
-                else:
-                    new_dict: Dict[str, torch.Tensor] = {}
-                    for k, v in cache_entry.state_dict.items():
-                        new_dict[k] = v.to(torch.device(target_device), copy=True)
-                    cache_entry.model.load_state_dict(new_dict, assign=True)
-            cache_entry.model.to(target_device)
-            cache_entry.device = target_device
-        except Exception as e:  # blow away cache entry
-            self._delete_cache_entry(cache_entry)
-            raise e
-        snapshot_after = self._capture_memory_snapshot()
-        end_model_to_time = time.time()
-        self.logger.debug(
-            f"Moved model '{cache_entry.key}' from {source_device} to"
-            f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
-            f"Estimated model size: {(cache_entry.size/GIG):.3f} GB."
-            f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
-        )
-        if (
-            snapshot_before is not None
-            and snapshot_after is not None
-            and snapshot_before.vram is not None
-            and snapshot_after.vram is not None
-        ):
-            vram_change = abs(snapshot_before.vram - snapshot_after.vram)
-            # If the estimated model size does not match the change in VRAM, log a warning.
-            if not math.isclose(
-                vram_change,
-                cache_entry.size,
-                rel_tol=0.1,
-                abs_tol=10 * MB,
-            ):
-                self.logger.debug(
-                    f"Moving model '{cache_entry.key}' from {source_device} to"
-                    f" {target_device} caused an unexpected change in VRAM usage. The model's"
-                    " estimated size may be incorrect. Estimated model size:"
-                    f" {(cache_entry.size/GIG):.3f} GB.\n"
-                    f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
-                )
     def print_cuda_stats(self) -> None:
         """Log CUDA diagnostics."""
         vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
@@ -440,12 +320,43 @@ class ModelCache(ModelCacheBase[AnyModel]):
         while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
             model_key = self._cache_stack[pos]
             cache_entry = self._cached_models[model_key]
+            refs = sys.getrefcount(cache_entry.model)
+            # HACK: This is a workaround for a memory-management issue that we haven't tracked down yet. We are directly
+            # going against the advice in the Python docs by using `gc.get_referrers(...)` in this way:
+            # https://docs.python.org/3/library/gc.html#gc.get_referrers
+            # manualy clear local variable references of just finished function calls
+            # for some reason python don't want to collect it even by gc.collect() immidiately
+            if refs > 2:
+                while True:
+                    cleared = False
+                    for referrer in gc.get_referrers(cache_entry.model):
+                        if type(referrer).__name__ == "frame":
+                            # RuntimeError: cannot clear an executing frame
+                            with suppress(RuntimeError):
+                                referrer.clear()
+                                cleared = True
+                            # break
+                    # repeat if referrers changes(due to frame clear), else exit loop
+                    if cleared:
+                        gc.collect()
+                    else:
+                        break
             device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
             self.logger.debug(
-                f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}"
+                f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded},"
+                f" refs: {refs}"
             )
-            if not cache_entry.locked:
+            # Expected refs:
+            # 1 from cache_entry
+            # 1 from getrefcount function
+            # 1 from onnx runtime object
+            if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2):
                 self.logger.debug(
                     f"Removing {model_key} from RAM cache to free at least {(size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
                 )
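
Note: the block restored in the final hunk gates eviction on the model's reference count and, when stray references remain, tries to clear lingering frame objects via `gc.get_referrers()` (the in-code HACK comment is explicit that this goes against the gc documentation's advice). A small standalone illustration of the reference-count threshold it relies on; the class and key below are made up:

```python
import sys


class FakeModel:
    """Stand-in for a cached model; illustration only."""


# One reference held by the cache, mirroring CacheRecord.model.
cache = {"model-key": FakeModel()}

# sys.getrefcount() adds one temporary reference for its own argument, so a
# model referenced only by the cache reports 2 - the "safe to evict" case.
print(sys.getrefcount(cache["model-key"]))  # -> 2

leaked = cache["model-key"]  # e.g. a local variable left alive in some frame
print(sys.getrefcount(cache["model-key"]))  # -> 3, so eviction would be skipped
```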

View File

@@ -2,6 +2,7 @@
 Base class and implementation of a class that moves models in and out of VRAM.
 """
+import copy
 from typing import Optional
 import torch
@@ -54,14 +55,13 @@ class ModelLocker(ModelLockerBase):
         # NOTE that the model has to have the to() method in order for this code to move it into GPU!
         self._cache_entry.lock()
         try:
-            if self._cache.lazy_offloading:
-                self._cache.offload_unlocked_models(self._cache_entry.size)
-            execution_device = self._cache.get_execution_device()
-            self._cache.move_model_to_device(self._cache_entry, execution_device)
+            # We wait for a gpu to be free - may raise a ValueError
+            self._execution_device = self._cache.get_execution_device()
+            self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._execution_device}")
+            model_in_gpu = copy.deepcopy(self._cache_entry.model)
+            if hasattr(model_in_gpu, "to"):
+                model_in_gpu.to(self._execution_device)
             self._cache_entry.loaded = True
-            self._cache.logger.debug(f"Locking {self._cache_entry.key} in {execution_device}")
             self._cache.print_cuda_stats()
         except torch.cuda.OutOfMemoryError:
             self._cache.logger.warning("Insufficient GPU memory to load model. Aborting")
@ -70,15 +70,11 @@ class ModelLocker(ModelLockerBase):
except Exception: except Exception:
self._cache_entry.unlock() self._cache_entry.unlock()
raise raise
return model_in_gpu
return self.model
def unlock(self) -> None: def unlock(self) -> None:
"""Call upon exit from context.""" """Call upon exit from context."""
if not hasattr(self.model, "to"): if not hasattr(self.model, "to"):
return return
self._cache_entry.unlock() self._cache_entry.unlock()
if not self._cache.lazy_offloading:
self._cache.offload_unlocked_models(0)
self._cache.print_cuda_stats() self._cache.print_cuda_stats()
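
Note: with `offload_unlocked_models()` and `move_model_to_device()` gone, `lock()` now deep-copies the cached model and moves the copy onto the reserved device, so the RAM-resident original never has to be moved back. A rough sketch of that idea, assuming only that the model supports `.to()` and fits on the target device; the function and variable names are illustrative, not the project's:

```python
import copy

import torch


def load_copy_on_device(cached_model: torch.nn.Module, device: torch.device) -> torch.nn.Module:
    # Copy the RAM-resident model rather than moving it; the cached original
    # stays untouched and the copy can simply be dropped when unlocked.
    model_on_device = copy.deepcopy(cached_model)
    if hasattr(model_on_device, "to"):
        model_on_device = model_on_device.to(device)
    return model_on_device


# Hypothetical usage: the original stays on the CPU; freeing the copy frees
# the device memory without any state_dict juggling.
cpu_model = torch.nn.Linear(8, 8)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gpu_model = load_copy_on_device(cpu_model, device)
```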

View File

@ -54,6 +54,7 @@ def mock_services() -> InvocationServices:
workflow_records=None, # type: ignore workflow_records=None, # type: ignore
tensors=None, # type: ignore tensors=None, # type: ignore
conditioning=None, # type: ignore conditioning=None, # type: ignore
performance_statistics=None, # type: ignore
) )