mirror of
https://github.com/invoke-ai/InvokeAI
synced 2025-07-26 05:17:55 +00:00
Calculate model cache size limits dynamically based on the available RAM / VRAM.
This commit is contained in:
@ -13,7 +13,6 @@ from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
import psutil
|
||||
import yaml
|
||||
from pydantic import BaseModel, Field, PrivateAttr, field_validator
|
||||
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict
|
||||
@ -25,8 +24,6 @@ from invokeai.frontend.cli.arg_parser import InvokeAIArgs
|
||||
INIT_FILE = Path("invokeai.yaml")
|
||||
DB_FILE = Path("invokeai.db")
|
||||
LEGACY_INIT_FILE = Path("invokeai.init")
|
||||
DEFAULT_RAM_CACHE = 10.0
|
||||
DEFAULT_VRAM_CACHE = 0.25
|
||||
DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"]
|
||||
PRECISION = Literal["auto", "float16", "bfloat16", "float32"]
|
||||
ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"]
|
||||
@ -36,24 +33,6 @@ LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"]
|
||||
CONFIG_SCHEMA_VERSION = "4.0.2"
|
||||
|
||||
|
||||
def get_default_ram_cache_size() -> float:
|
||||
"""Run a heuristic for the default RAM cache based on installed RAM."""
|
||||
|
||||
# On some machines, psutil.virtual_memory().total gives a value that is slightly less than the actual RAM, so the
|
||||
# limits are set slightly lower than than what we expect the actual RAM to be.
|
||||
|
||||
GB = 1024**3
|
||||
max_ram = psutil.virtual_memory().total / GB
|
||||
|
||||
if max_ram >= 60:
|
||||
return 15.0
|
||||
if max_ram >= 30:
|
||||
return 7.5
|
||||
if max_ram >= 14:
|
||||
return 4.0
|
||||
return 2.1 # 2.1 is just large enough for sd 1.5 ;-)
|
||||
|
||||
|
||||
class URLRegexTokenPair(BaseModel):
|
||||
url_regex: str = Field(description="Regular expression to match against the URL")
|
||||
token: str = Field(description="Token to use when the URL matches the regex")
|
||||
@ -103,11 +82,12 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
profile_graphs: Enable graph profiling using `cProfile`.
|
||||
profile_prefix: An optional prefix for profile output files.
|
||||
profiles_dir: Path to profiles output directory.
|
||||
ram: Maximum memory amount used by memory model cache for rapid switching (GB).
|
||||
vram: Amount of VRAM reserved for model storage (GB).
|
||||
lazy_offload: Keep models in VRAM until their space is needed.
|
||||
ram: The maximum amount of CPU RAM to use for model caching in GB. If unset, the limit will be configured based on the available RAM. In most cases, it is recommended to leave this unset.
|
||||
vram: The amount of VRAM to use for model caching in GB. If unset, the limit will be configured based on the available VRAM and the device_working_mem_gb. In most cases, it is recommended to leave this unset.
|
||||
lazy_offload: DEPRECATED: This setting is no longer used. Lazy-offloading is enabled by default. This config setting will be removed once the new model cache behaviour is out of beta.
|
||||
log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
|
||||
enable_partial_loading: Enable partial loading of models. This enables models to run with reduced VRAM requirements (at the cost of slower speed) by streaming the model from RAM to VRAM as its used. Partial loading can cause models to run more slowly if they were previously being fully loaded into VRAM. If enabling this setting, make sure that your ram and vram cache limits are properly tuned.
|
||||
device_working_mem_gb: The amount of working memory to keep available on the compute device (in GB). Has no effect if running on CPU. If you are experiencing OOM errors, try increasing this value.
|
||||
enable_partial_loading: Enable partial loading of models. This enables models to run with reduced VRAM requirements (at the cost of slower speed) by streaming the model from RAM to VRAM as its used. In some edge cases, partial loading can cause models to run more slowly if they were previously being fully loaded into VRAM.
|
||||
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
|
||||
precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`
|
||||
sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
|
||||
@ -175,11 +155,12 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
profiles_dir: Path = Field(default=Path("profiles"), description="Path to profiles output directory.")
|
||||
|
||||
# CACHE
|
||||
ram: float = Field(default_factory=get_default_ram_cache_size, gt=0, description="Maximum memory amount used by memory model cache for rapid switching (GB).")
|
||||
vram: float = Field(default=DEFAULT_VRAM_CACHE, ge=0, description="Amount of VRAM reserved for model storage (GB).")
|
||||
lazy_offload: bool = Field(default=True, description="Keep models in VRAM until their space is needed.")
|
||||
ram: Optional[float] = Field(default=None, gt=0, description="The maximum amount of CPU RAM to use for model caching in GB. If unset, the limit will be configured based on the available RAM. In most cases, it is recommended to leave this unset.")
|
||||
vram: Optional[float] = Field(default=None, ge=0, description="The amount of VRAM to use for model caching in GB. If unset, the limit will be configured based on the available VRAM and the device_working_mem_gb. In most cases, it is recommended to leave this unset.")
|
||||
lazy_offload: bool = Field(default=True, description="DEPRECATED: This setting is no longer used. Lazy-offloading is enabled by default. This config setting will be removed once the new model cache behaviour is out of beta.")
|
||||
log_memory_usage: bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.")
|
||||
enable_partial_loading: bool = Field(default=False, description="Enable partial loading of models. This enables models to run with reduced VRAM requirements (at the cost of slower speed) by streaming the model from RAM to VRAM as its used. Partial loading can cause models to run more slowly if they were previously being fully loaded into VRAM. If enabling this setting, make sure that your ram and vram cache limits are properly tuned.")
|
||||
device_working_mem_gb: float = Field(default=3, description="The amount of working memory to keep available on the compute device (in GB). Has no effect if running on CPU. If you are experiencing OOM errors, try increasing this value.")
|
||||
enable_partial_loading: bool = Field(default=False, description="Enable partial loading of models. This enables models to run with reduced VRAM requirements (at the cost of slower speed) by streaming the model from RAM to VRAM as its used. In some edge cases, partial loading can cause models to run more slowly if they were previously being fully loaded into VRAM.")
|
||||
|
||||
# DEVICE
|
||||
device: DEVICE = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.")
|
||||
|
@ -82,11 +82,12 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
logger.setLevel(app_config.log_level.upper())
|
||||
|
||||
ram_cache = ModelCache(
|
||||
execution_device_working_mem_gb=app_config.device_working_mem_gb,
|
||||
enable_partial_loading=app_config.enable_partial_loading,
|
||||
max_ram_cache_size_gb=app_config.ram,
|
||||
max_vram_cache_size_gb=app_config.vram,
|
||||
enable_partial_loading=app_config.enable_partial_loading,
|
||||
logger=logger,
|
||||
execution_device=execution_device or TorchDevice.choose_torch_device(),
|
||||
logger=logger,
|
||||
)
|
||||
loader = ModelLoadService(
|
||||
app_config=app_config,
|
||||
|
@ -4,6 +4,7 @@ import time
|
||||
from logging import Logger
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import psutil
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager import AnyModel, SubModelType
|
||||
@ -75,9 +76,10 @@ class ModelCache:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_ram_cache_size_gb: float,
|
||||
max_vram_cache_size_gb: float,
|
||||
execution_device_working_mem_gb: float,
|
||||
enable_partial_loading: bool,
|
||||
max_ram_cache_size_gb: float | None = None,
|
||||
max_vram_cache_size_gb: float | None = None,
|
||||
execution_device: torch.device | str = "cuda",
|
||||
storage_device: torch.device | str = "cpu",
|
||||
log_memory_usage: bool = False,
|
||||
@ -85,6 +87,9 @@ class ModelCache:
|
||||
):
|
||||
"""Initialize the model RAM cache.
|
||||
|
||||
:param execution_device_working_mem_gb: The amount of working memory to keep on the GPU (in GB) i.e. non-model
|
||||
VRAM.
|
||||
:param enable_partial_loading: Whether to enable partial loading of models.
|
||||
:param max_ram_cache_size_gb: The maximum amount of CPU RAM to use for model caching in GB. This parameter is
|
||||
kept to maintain compatibility with previous versions of the model cache, but should be deprecated in the
|
||||
future. If set, this parameter overrides the default cache size logic.
|
||||
@ -99,12 +104,13 @@ class ModelCache:
|
||||
behaviour.
|
||||
:param logger: InvokeAILogger to use (otherwise creates one)
|
||||
"""
|
||||
self._enable_partial_loading = enable_partial_loading
|
||||
self._execution_device_working_mem_gb = execution_device_working_mem_gb
|
||||
self._execution_device: torch.device = torch.device(execution_device)
|
||||
self._storage_device: torch.device = torch.device(storage_device)
|
||||
|
||||
self._max_ram_cache_size_gb = max_ram_cache_size_gb
|
||||
self._max_vram_cache_size_gb = max_vram_cache_size_gb
|
||||
self._enable_partial_loading = enable_partial_loading
|
||||
|
||||
self._logger = PrefixedLoggerAdapter(
|
||||
logger or InvokeAILogger.get_logger(self.__class__.__name__), "MODEL CACHE"
|
||||
@ -244,7 +250,7 @@ class ModelCache:
|
||||
f"Unlocked model {cache_entry.key} (Type: {cache_entry.cached_model.model.__class__.__name__})"
|
||||
)
|
||||
|
||||
def _load_locked_model(self, cache_entry: CacheRecord) -> None:
|
||||
def _load_locked_model(self, cache_entry: CacheRecord, working_mem_bytes: Optional[int] = None) -> None:
|
||||
"""Helper function for self.lock(). Loads a locked model into VRAM."""
|
||||
start_time = time.time()
|
||||
|
||||
@ -254,7 +260,7 @@ class ModelCache:
|
||||
model_total_bytes = cache_entry.cached_model.total_bytes()
|
||||
model_vram_needed = model_total_bytes - model_cur_vram_bytes
|
||||
|
||||
vram_available = self._get_vram_available()
|
||||
vram_available = self._get_vram_available(working_mem_bytes)
|
||||
self._logger.debug(
|
||||
f"Before unloading: {self._get_vram_state_str(model_cur_vram_bytes, model_total_bytes, vram_available)}"
|
||||
)
|
||||
@ -267,7 +273,7 @@ class ModelCache:
|
||||
self._logger.debug(f"Unloaded models (if necessary): vram_bytes_freed={(vram_bytes_freed/MB):.2f}MB")
|
||||
|
||||
# Check the updated vram_available after offloading.
|
||||
vram_available = self._get_vram_available()
|
||||
vram_available = self._get_vram_available(working_mem_bytes)
|
||||
self._logger.debug(
|
||||
f"After unloading: {self._get_vram_state_str(model_cur_vram_bytes, model_total_bytes, vram_available)}"
|
||||
)
|
||||
@ -280,7 +286,7 @@ class ModelCache:
|
||||
model_bytes_loaded = self._move_model_to_vram(cache_entry, vram_available + MB)
|
||||
|
||||
model_cur_vram_bytes = cache_entry.cached_model.cur_vram_bytes()
|
||||
vram_available = self._get_vram_available()
|
||||
vram_available = self._get_vram_available(working_mem_bytes)
|
||||
loaded_percent = model_cur_vram_bytes / model_total_bytes if model_total_bytes > 0 else 0
|
||||
self._logger.info(
|
||||
f"Loaded model '{cache_entry.key}' ({cache_entry.cached_model.model.__class__.__name__}) onto "
|
||||
@ -298,7 +304,7 @@ class ModelCache:
|
||||
if isinstance(cache_entry.cached_model, CachedModelWithPartialLoad):
|
||||
return cache_entry.cached_model.partial_load_to_vram(vram_available)
|
||||
elif isinstance(cache_entry.cached_model, CachedModelOnlyFullLoad): # type: ignore
|
||||
# Partial load is not supported, so we have no choice but to try and fit it all into VRAM.
|
||||
# Partial load is not supported, so we have not choice but to try and fit it all into VRAM.
|
||||
return cache_entry.cached_model.full_load_to_vram()
|
||||
else:
|
||||
raise ValueError(f"Unsupported cached model type: {type(cache_entry.cached_model)}")
|
||||
@ -322,10 +328,33 @@ class ModelCache:
|
||||
self._delete_cache_entry(cache_entry)
|
||||
raise
|
||||
|
||||
def _get_vram_available(self) -> int:
|
||||
"""Calculate the amount of additional VRAM available for the cache to use."""
|
||||
vram_total_available_to_cache = int(self._max_vram_cache_size_gb * GB)
|
||||
return vram_total_available_to_cache - self._get_vram_in_use()
|
||||
def _get_vram_available(self, working_mem_bytes: Optional[int] = None) -> int:
|
||||
"""Calculate the amount of additional VRAM available for the cache to use (takes into account the working
|
||||
memory).
|
||||
"""
|
||||
# If self._max_vram_cache_size_gb is set, then it overrides the default logic.
|
||||
if self._max_vram_cache_size_gb is not None:
|
||||
vram_total_available_to_cache = int(self._max_vram_cache_size_gb * GB)
|
||||
return vram_total_available_to_cache - self._get_vram_in_use()
|
||||
|
||||
working_mem_bytes_default = int(self._execution_device_working_mem_gb * GB)
|
||||
working_mem_bytes = max(working_mem_bytes or working_mem_bytes_default, working_mem_bytes_default)
|
||||
|
||||
if self._execution_device.type == "cuda":
|
||||
vram_reserved = torch.cuda.memory_reserved(self._execution_device)
|
||||
vram_free, _vram_total = torch.cuda.mem_get_info(self._execution_device)
|
||||
vram_available_to_process = vram_free + vram_reserved
|
||||
elif self._execution_device.type == "mps":
|
||||
vram_reserved = torch.mps.driver_allocated_memory()
|
||||
# TODO(ryand): Is it accurate that MPS shares memory with the CPU?
|
||||
vram_free = psutil.virtual_memory().available
|
||||
vram_available_to_process = vram_free + vram_reserved
|
||||
else:
|
||||
raise ValueError(f"Unsupported execution device: {self._execution_device.type}")
|
||||
|
||||
vram_total_available_to_cache = vram_available_to_process - working_mem_bytes
|
||||
vram_cur_available_to_cache = vram_total_available_to_cache - self._get_vram_in_use()
|
||||
return vram_cur_available_to_cache
|
||||
|
||||
def _get_vram_in_use(self) -> int:
|
||||
"""Get the amount of VRAM currently in use by the cache."""
|
||||
@ -341,8 +370,25 @@ class ModelCache:
|
||||
def _get_ram_available(self) -> int:
|
||||
"""Get the amount of RAM available for the cache to use, while keeping memory pressure under control."""
|
||||
|
||||
ram_total_available_to_cache = int(self._max_ram_cache_size_gb * GB)
|
||||
return ram_total_available_to_cache - self._get_ram_in_use()
|
||||
# If self._max_ram_cache_size_gb is set, then it overrides the default logic.
|
||||
if self._max_ram_cache_size_gb is not None:
|
||||
ram_total_available_to_cache = int(self._max_ram_cache_size_gb * GB)
|
||||
return ram_total_available_to_cache - self._get_ram_in_use()
|
||||
|
||||
virtual_memory = psutil.virtual_memory()
|
||||
ram_total = virtual_memory.total
|
||||
ram_available = virtual_memory.available
|
||||
ram_used = ram_total - ram_available
|
||||
|
||||
# The total size of all the models in the cache will often be larger than the amount of RAM reported by psutil
|
||||
# (due to lazy-loading and OS RAM caching behaviour). We could just rely on the psutil values, but it feels
|
||||
# like a bad idea to over-fill the model cache. So, for now, we'll try to keep the total size of models in the
|
||||
# cache under the total amount of system RAM.
|
||||
cache_ram_used = self._get_ram_in_use()
|
||||
ram_used = max(cache_ram_used, ram_used)
|
||||
|
||||
# Aim to keep 10% of RAM free.
|
||||
return int(ram_total * 0.9) - ram_used
|
||||
|
||||
def _get_ram_in_use(self) -> int:
|
||||
"""Get the amount of RAM currently in use."""
|
||||
|
@ -91,10 +91,11 @@ def mm2_download_queue(mm2_session: Session) -> DownloadQueueServiceBase:
|
||||
@pytest.fixture
|
||||
def mm2_loader(mm2_app_config: InvokeAIAppConfig) -> ModelLoadServiceBase:
|
||||
ram_cache = ModelCache(
|
||||
logger=InvokeAILogger.get_logger(),
|
||||
execution_device_working_mem_gb=mm2_app_config.device_working_mem_gb,
|
||||
enable_partial_loading=mm2_app_config.enable_partial_loading,
|
||||
max_ram_cache_size_gb=mm2_app_config.ram,
|
||||
max_vram_cache_size_gb=mm2_app_config.vram,
|
||||
enable_partial_loading=mm2_app_config.enable_partial_loading,
|
||||
logger=InvokeAILogger.get_logger(),
|
||||
)
|
||||
return ModelLoadService(
|
||||
app_config=mm2_app_config,
|
||||
|
Reference in New Issue
Block a user