revert to old system for doing RAM <-> VRAM transfers; new way leaks memory

Author: Lincoln Stein
Date: 2024-04-17 09:51:57 -04:00
Parent: 84f5cbdd97
Commit: c3d1252892
5 changed files with 141 additions and 16 deletions

File 1 of 5

@@ -24,6 +24,7 @@ INIT_FILE = Path("invokeai.yaml")
 DB_FILE = Path("invokeai.db")
 LEGACY_INIT_FILE = Path("invokeai.init")
 DEFAULT_RAM_CACHE = 10.0
+DEFAULT_VRAM_CACHE = 0.25
 DEFAULT_CONVERT_CACHE = 20.0
 DEVICE = Literal["auto", "cpu", "cuda:0", "cuda:1", "cuda:2", "cuda:3", "cuda:4", "cuda:5", "cuda:6", "cuda:7", "mps"]
 PRECISION = Literal["auto", "float16", "bfloat16", "float32", "autocast"]
@@ -99,7 +100,9 @@ class InvokeAIAppConfig(BaseSettings):
         profile_prefix: An optional prefix for profile output files.
         profiles_dir: Path to profiles output directory.
         ram: Maximum memory amount used by memory model cache for rapid switching (GB).
+        vram: Amount of VRAM reserved for model storage (GB).
         convert_cache: Maximum size of on-disk converted models cache (GB).
+        lazy_offload: Keep models in VRAM until their space is needed.
         log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
         device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda:0`, `cuda:1`, `cuda:2`, `cuda:3`, `cuda:4`, `cuda:5`, `cuda:6`, `cuda:7`, `mps`
         devices: List of execution devices; will override default device selected.
@@ -167,7 +170,9 @@ class InvokeAIAppConfig(BaseSettings):
     # CACHE
     ram: float = Field(default_factory=get_default_ram_cache_size, gt=0, description="Maximum memory amount used by memory model cache for rapid switching (GB).")
+    vram: float = Field(default=DEFAULT_VRAM_CACHE, ge=0, description="Amount of VRAM reserved for model storage (GB).")
     convert_cache: float = Field(default=DEFAULT_CONVERT_CACHE, ge=0, description="Maximum size of on-disk converted models cache (GB).")
+    lazy_offload: bool = Field(default=True, description="Keep models in VRAM until their space is needed.")
     log_memory_usage: bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.")

     # DEVICE
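
The two hunks above restore `vram` and `lazy_offload` as first-class fields on `InvokeAIAppConfig`. A minimal sketch of reading and overriding the restored settings, assuming the import path below matches this tree and that the class's remaining fields keep their defaults:

    from invokeai.app.services.config.config_default import InvokeAIAppConfig  # module path assumed

    # Defaults come from the constants above: ram via get_default_ram_cache_size(),
    # vram from DEFAULT_VRAM_CACHE (0.25 GB), lazy_offload defaulting to True.
    cfg = InvokeAIAppConfig()
    print(cfg.ram, cfg.vram, cfg.lazy_offload)

    # Explicit overrides, e.g. dedicating 2 GB of VRAM to resident models:
    cfg = InvokeAIAppConfig(vram=2.0, lazy_offload=True)
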
@@ -366,9 +371,6 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
             # `max_cache_size` was renamed to `ram` some time in v3, but both names were used
             if k == "max_cache_size" and "ram" not in category_dict:
                 parsed_config_dict["ram"] = v
-            # vram was removed in v4.0.2
-            if k in ["vram", "max_vram_cache_size", "lazy_offload"]:
-                continue
             # autocast was removed in v4.0.1
             if k == "precision" and v == "autocast":
                 parsed_config_dict["precision"] = "auto"
@@ -419,6 +421,9 @@ def migrate_v4_0_0_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
 def migrate_v4_0_1_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
     """Migrate v4.0.1 config dictionary to a current config object.
+
+    A few new multi-GPU options were added in 4.0.2, and this simply
+    updates the schema label.

     Args:
         config_dict: A dictionary of settings from a v4.0.1 config file.
@@ -426,15 +431,14 @@ def migrate_v4_0_1_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
         An instance of `InvokeAIAppConfig` with the migrated settings.
     """
     parsed_config_dict: dict[str, Any] = {}
-    for k, v in config_dict.items():
-        if k not in ["vram", "lazy_offload"]:
-            parsed_config_dict[k] = v
+    for k, _ in config_dict.items():
         if k == "schema_version":
             parsed_config_dict[k] = CONFIG_SCHEMA_VERSION
     config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict)
     return config


+# TO DO: replace this with a formal registration and migration system
 def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
     """Load and migrate a config file to the latest version.

File 2 of 5

@@ -76,6 +76,8 @@ class ModelManagerService(ModelManagerServiceBase):
         ram_cache = ModelCache(
             max_cache_size=app_config.ram,
+            max_vram_cache_size=app_config.vram,
+            lazy_offloading=app_config.lazy_offload,
             logger=logger,
         )
         convert_cache = ModelConvertCache(cache_path=app_config.convert_cache_path, max_size=app_config.convert_cache)
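
This hunk threads the restored settings into the cache when the model manager service is built. A standalone sketch of the same construction, assuming the module path below and a plain standard-library logger; the other `ModelCache` parameters keep the defaults shown later in this diff (CPU storage device, float16 precision):

    import logging
    from invokeai.backend.model_manager.load.model_cache.model_cache_default import ModelCache  # path assumed

    cache = ModelCache(
        max_cache_size=10.0,       # GB of RAM for resident models (app_config.ram)
        max_vram_cache_size=0.25,  # GB of VRAM reserved for models (app_config.vram)
        lazy_offloading=True,      # keep models in VRAM until space is needed (app_config.lazy_offload)
        logger=logging.getLogger("ModelCache"),
    )
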

File 3 of 5

@@ -113,12 +113,28 @@ class ModelCacheBase(ABC, Generic[T]):
         """
         pass

+    @property
+    @abstractmethod
+    def lazy_offloading(self) -> bool:
+        """Return true if the cache is configured to lazily offload models in VRAM."""
+        pass
+
     @property
     @abstractmethod
     def max_cache_size(self) -> float:
         """Return true if the cache is configured to lazily offload models in VRAM."""
         pass

+    @abstractmethod
+    def offload_unlocked_models(self, size_required: int) -> None:
+        """Offload from VRAM any models not actively in use."""
+        pass
+
+    @abstractmethod
+    def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
+        """Move model into the indicated device."""
+        pass
+
     @property
     @abstractmethod
     def stats(self) -> Optional[CacheStats]:
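
The new abstract members spell out the contract the locker relies on further below: check `lazy_offloading`, free VRAM with `offload_unlocked_models`, then bring the wanted model in with `move_model_to_device`. A rough sketch of that calling order, where `load_into_vram` is a hypothetical helper and the types are the ones declared in this module:

    import torch

    # Hypothetical helper illustrating the intended calling order against any
    # ModelCacheBase implementation; not part of the diff above.
    def load_into_vram(cache: "ModelCacheBase[AnyModel]", entry: "CacheRecord[AnyModel]", device: torch.device) -> None:
        if cache.lazy_offloading:
            # Make room first: evict unlocked models until the VRAM reserve is respected.
            cache.offload_unlocked_models(entry.size)
        cache.move_model_to_device(entry, device)
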

File 4 of 5

@@ -19,8 +19,10 @@ context. Use like this:
 """

 import gc
+import math
 import sys
 import threading
+import time
 from contextlib import contextmanager, suppress
 from logging import Logger
 from threading import BoundedSemaphore
@@ -29,7 +31,7 @@ from typing import Dict, Generator, List, Optional, Set
 import torch

 from invokeai.backend.model_manager import AnyModel, SubModelType
-from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot
+from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
 from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.logging import InvokeAILogger
@@ -40,6 +42,11 @@ from .model_locker import ModelLocker

 # Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
 DEFAULT_MAX_CACHE_SIZE = 6.0

+# amount of GPU memory to hold in reserve for use by generations (GB)
+# Empirically this value seems to improve performance without starving other
+# processes.
+DEFAULT_MAX_VRAM_CACHE_SIZE = 0.25
+
 # actual size of a gig
 GIG = 1073741824
@@ -53,10 +60,12 @@ class ModelCache(ModelCacheBase[AnyModel]):
     def __init__(
         self,
         max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
+        max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
         storage_device: torch.device = torch.device("cpu"),
         execution_devices: Optional[Set[torch.device]] = None,
         precision: torch.dtype = torch.float16,
         sequential_offload: bool = False,
+        lazy_offloading: bool = True,
         sha_chunksize: int = 16777216,
         log_memory_usage: bool = False,
         logger: Optional[Logger] = None,
@@ -67,14 +76,18 @@ class ModelCache(ModelCacheBase[AnyModel]):
         :param max_cache_size: Maximum size of the RAM cache [6.0 GB]
         :param storage_device: Torch device to save inactive model in [torch.device('cpu')]
         :param precision: Precision for loaded models [torch.float16]
+        :param lazy_offloading: Keep model in VRAM until another model needs to be loaded
         :param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
         :param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
             operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
             snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
             behaviour.
         """
+        # allow lazy offloading only when vram cache enabled
+        self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0
         self._precision: torch.dtype = precision
         self._max_cache_size: float = max_cache_size
+        self._max_vram_cache_size: float = max_vram_cache_size
         self._storage_device: torch.device = storage_device
         self._ram_lock = threading.Lock()
         self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
@@ -98,6 +111,11 @@ class ModelCache(ModelCacheBase[AnyModel]):
         """Return the logger used by the cache."""
         return self._logger

+    @property
+    def lazy_offloading(self) -> bool:
+        """Return true if the cache is configured to lazily offload models in VRAM."""
+        return self._lazy_offloading
+
     @property
     def storage_device(self) -> torch.device:
         """Return the storage device (e.g. "CPU" for RAM)."""
@@ -277,6 +295,87 @@ class ModelCache(ModelCacheBase[AnyModel]):
         else:
             return model_key

+    def offload_unlocked_models(self, size_required: int) -> None:
+        """Move any unused models from VRAM."""
+        reserved = self._max_vram_cache_size * GIG
+        vram_in_use = torch.cuda.memory_allocated() + size_required
+        self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB")
+        for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
+            if vram_in_use <= reserved:
+                break
+            if not cache_entry.loaded:
+                continue
+            if not cache_entry.locked:
+                self.move_model_to_device(cache_entry, self.storage_device)
+                cache_entry.loaded = False
+                vram_in_use = torch.cuda.memory_allocated() + size_required
+                self.logger.debug(
+                    f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
+                )
+
+        TorchDevice.empty_cache()
+
+    def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
+        """Move model into the indicated device.
+
+        :param cache_entry: The CacheRecord for the model
+        :param target_device: The torch.device to move the model into
+
+        May raise a torch.cuda.OutOfMemoryError
+        """
+        # These attributes are not in the base ModelMixin class but in various derived classes.
+        # Some models don't have these attributes, in which case they run in RAM/CPU.
+        self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
+        if not (hasattr(cache_entry.model, "device") and hasattr(cache_entry.model, "to")):
+            return
+
+        source_device = cache_entry.model.device
+
+        # Note: We compare device types only so that 'cuda' == 'cuda:0'.
+        # This would need to be revised to support multi-GPU.
+        if torch.device(source_device).type == torch.device(target_device).type:
+            return
+
+        start_model_to_time = time.time()
+        snapshot_before = self._capture_memory_snapshot()
+        try:
+            cache_entry.model.to(target_device)
+        except Exception as e:  # blow away cache entry
+            self._delete_cache_entry(cache_entry)
+            raise e
+        snapshot_after = self._capture_memory_snapshot()
+        end_model_to_time = time.time()
+        self.logger.debug(
+            f"Moved model '{cache_entry.key}' from {source_device} to"
+            f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
+            f"Estimated model size: {(cache_entry.size/GIG):.3f} GB."
+            f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
+        )
+
+        if (
+            snapshot_before is not None
+            and snapshot_after is not None
+            and snapshot_before.vram is not None
+            and snapshot_after.vram is not None
+        ):
+            vram_change = abs(snapshot_before.vram - snapshot_after.vram)
+
+            # If the estimated model size does not match the change in VRAM, log a warning.
+            if not math.isclose(
+                vram_change,
+                cache_entry.size,
+                rel_tol=0.1,
+                abs_tol=10 * MB,
+            ):
+                self.logger.debug(
+                    f"Moving model '{cache_entry.key}' from {source_device} to"
+                    f" {target_device} caused an unexpected change in VRAM usage. The model's"
+                    " estimated size may be incorrect. Estimated model size:"
+                    f" {(cache_entry.size/GIG):.3f} GB.\n"
+                    f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
+                )
+
     def print_cuda_stats(self) -> None:
         """Log CUDA diagnostics."""
         vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
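
`offload_unlocked_models()` above works in bytes against a fixed reserve. A small self-contained sketch of that arithmetic, reusing the GIG constant and the 0.25 GB default reserve; the allocation figures are invented for illustration:

    GIG = 1073741824  # bytes per GiB, as defined above

    max_vram_cache_size = 0.25            # GB, DEFAULT_MAX_VRAM_CACHE_SIZE
    reserved = max_vram_cache_size * GIG  # ~268 MB of models allowed to stay resident

    # Suppose torch.cuda.memory_allocated() reports 2 GiB and a 1 GiB model is about to be locked:
    allocated = 2 * GIG
    size_required = 1 * GIG
    vram_in_use = allocated + size_required

    # vram_in_use (3 GiB) exceeds the reserve, so the eviction loop keeps moving
    # unlocked models back to the storage device (smallest first, per the sorted() key)
    # until the figure drops below `reserved` or nothing evictable remains.
    print(vram_in_use > reserved)  # True -> eviction runs

The post-move sanity check in `move_model_to_device()` then compares the observed VRAM delta against the entry's estimated size via `math.isclose(..., rel_tol=0.1, abs_tol=10 * MB)`; `MB` is not shown in this hunk and presumably is defined elsewhere in the module.
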

File 5 of 5

@@ -2,7 +2,6 @@
 Base class and implementation of a class that moves models in and out of VRAM.
 """

-import copy
 from typing import Optional

 import torch
@@ -55,13 +54,14 @@ class ModelLocker(ModelLockerBase):
         # NOTE that the model has to have the to() method in order for this code to move it into GPU!
         self._cache_entry.lock()
         try:
-            # We wait for a gpu to be free - may raise a ValueError
-            self._execution_device = self._cache.get_execution_device()
-            self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._execution_device}")
-            model_in_gpu = copy.deepcopy(self._cache_entry.model)
-            if hasattr(model_in_gpu, "to"):
-                model_in_gpu.to(self._execution_device)
+            if self._cache.lazy_offloading:
+                self._cache.offload_unlocked_models(self._cache_entry.size)
+
+            execution_device = self._cache.get_execution_device()
+            self._cache.move_model_to_device(self._cache_entry, execution_device)
             self._cache_entry.loaded = True
+
+            self._cache.logger.debug(f"Locking {self._cache_entry.key} in {execution_device}")
             self._cache.print_cuda_stats()
         except torch.cuda.OutOfMemoryError:
             self._cache.logger.warning("Insufficient GPU memory to load model. Aborting")
@@ -70,11 +70,15 @@ class ModelLocker(ModelLockerBase):
         except Exception:
             self._cache_entry.unlock()
             raise
-        return model_in_gpu
+
+        return self.model

     def unlock(self) -> None:
         """Call upon exit from context."""
         if not hasattr(self.model, "to"):
             return
+
         self._cache_entry.unlock()
+        if not self._cache.lazy_offloading:
+            self._cache.offload_unlocked_models(self._cache_entry.size)
         self._cache.print_cuda_stats()
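
With the cache methods restored, the locker offloads either up front (lazy mode, inside `lock()`) or right after release (eager mode, inside `unlock()`). A compact sketch of the resulting lifecycle; `locker` stands in for a `ModelLocker` instance, `run_inference` is a hypothetical caller, and the higher-level context manager that normally drives these calls is omitted:

    model = locker.lock()     # lazy_offloading: evict unlocked models, then move this entry to the GPU
    try:
        run_inference(model)  # hypothetical work while the entry is locked in VRAM
    finally:
        locker.unlock()       # with lazy_offloading == False, eviction happens here instead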