diff --git a/docs/contributing/MODEL_MANAGER.md b/docs/contributing/MODEL_MANAGER.md
index 98e8702c8f..2b843fe042 100644
--- a/docs/contributing/MODEL_MANAGER.md
+++ b/docs/contributing/MODEL_MANAGER.md
@@ -1345,7 +1345,7 @@ from invokeai.app.services.model_load import ModelLoadService, ModelLoaderRegist
config = InvokeAIAppConfig.get_config()
ram_cache = ModelCache(
- max_cache_size=config.ram_cache_size, max_vram_cache_size=config.vram_cache_size, logger=logger
+ max_cache_size=config.ram_cache_size, logger=logger
)
convert_cache = ModelConvertCache(
cache_path=config.models_convert_cache_path, max_size=config.convert_cache_size
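
With the VRAM cache gone, the documented wiring reduces to a single RAM budget. A minimal sketch of the updated example, assuming the import paths and logger setup used elsewhere in MODEL_MANAGER.md (the config attribute names come straight from the hunk above; the import lines are assumptions):

```python
# Minimal sketch of the documented cache construction after this change.
# The config attribute names come from the hunk above; the import paths and
# logger setup are assumptions and may differ from the surrounding docs.
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache
from invokeai.backend.util.logging import InvokeAILogger

logger = InvokeAILogger.get_logger()
config = InvokeAIAppConfig.get_config()

# The VRAM-related keywords (max_vram_cache_size, lazy_offloading) are gone;
# only the RAM budget is passed.
ram_cache = ModelCache(max_cache_size=config.ram_cache_size, logger=logger)
convert_cache = ModelConvertCache(
    cache_path=config.models_convert_cache_path, max_size=config.convert_cache_size
)
```
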
diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py
index c23dd3d908..a0928f37ac 100644
--- a/invokeai/app/invocations/compel.py
+++ b/invokeai/app/invocations/compel.py
@@ -58,65 +58,62 @@ class CompelInvocation(BaseInvocation):
tokenizer_model = tokenizer_info.model
assert isinstance(tokenizer_model, CLIPTokenizer)
text_encoder_info = context.models.load(self.clip.text_encoder)
- text_encoder_model = text_encoder_info.model
- assert isinstance(text_encoder_model, CLIPTextModel)
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.clip.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
- yield (lora_info.model, lora.weight)
+ with lora_info as model:
+ yield (model, lora.weight)
del lora_info
return
- # loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
-
ti_list = generate_ti_list(self.prompt, text_encoder_info.config.base, context)
- with (
- ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as (
- tokenizer,
- ti_manager,
- ),
- text_encoder_info as text_encoder,
- # Apply the LoRA after text_encoder has been moved to its target device for faster patching.
- ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
- # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
- ModelPatcher.apply_clip_skip(text_encoder_model, self.clip.skipped_layers),
- ):
- assert isinstance(text_encoder, CLIPTextModel)
- compel = Compel(
- tokenizer=tokenizer,
- text_encoder=text_encoder,
- textual_inversion_manager=ti_manager,
- dtype_for_device_getter=torch_dtype,
- truncate_long_prompts=False,
- )
-
- conjunction = Compel.parse_prompt_string(self.prompt)
-
- if context.config.get().log_tokenization:
- log_tokenization_for_conjunction(conjunction, tokenizer)
-
- c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
-
- ec = ExtraConditioningInfo(
- tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
- cross_attention_control_args=options.get("cross_attention_control", None),
- )
-
- c = c.detach().to("cpu")
-
- conditioning_data = ConditioningFieldData(
- conditionings=[
- BasicConditioningInfo(
- embeds=c,
- extra_conditioning=ec,
+ with text_encoder_info as text_encoder:
+ with (
+ ModelPatcher.apply_ti(tokenizer_model, text_encoder, ti_list) as (
+ tokenizer,
+ ti_manager,
+ ),
+ # Apply the LoRA after text_encoder has been moved to its target device for faster patching.
+ ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
+ # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
+ ModelPatcher.apply_clip_skip(text_encoder, self.clip.skipped_layers),
+ ):
+ assert isinstance(text_encoder, CLIPTextModel)
+ compel = Compel(
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ textual_inversion_manager=ti_manager,
+ dtype_for_device_getter=torch_dtype,
+ truncate_long_prompts=False,
)
- ]
- )
- conditioning_name = context.conditioning.save(conditioning_data)
+ conjunction = Compel.parse_prompt_string(self.prompt)
+
+ if context.config.get().log_tokenization:
+ log_tokenization_for_conjunction(conjunction, tokenizer)
+
+ c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
+
+ ec = ExtraConditioningInfo(
+ tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
+ cross_attention_control_args=options.get("cross_attention_control", None),
+ )
+
+ c = c.detach().to("cpu")
+
+ conditioning_data = ConditioningFieldData(
+ conditionings=[
+ BasicConditioningInfo(
+ embeds=c,
+ extra_conditioning=ec,
+ )
+ ]
+ )
+
+ conditioning_name = context.conditioning.save(conditioning_data)
return ConditioningOutput.build(conditioning_name)
@@ -137,8 +134,7 @@ class SDXLPromptInvocationBase:
tokenizer_model = tokenizer_info.model
assert isinstance(tokenizer_model, CLIPTokenizer)
text_encoder_info = context.models.load(clip_field.text_encoder)
- text_encoder_model = text_encoder_info.model
- assert isinstance(text_encoder_model, (CLIPTextModel, CLIPTextModelWithProjection))
+ assert isinstance(text_encoder_info.model, (CLIPTextModel, CLIPTextModelWithProjection))
# return zero on empty
if prompt == "" and zero_on_empty:
@@ -174,55 +170,55 @@ class SDXLPromptInvocationBase:
ti_list = generate_ti_list(prompt, text_encoder_info.config.base, context)
- with (
- ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as (
- tokenizer,
- ti_manager,
- ),
- text_encoder_info as text_encoder,
- # Apply the LoRA after text_encoder has been moved to its target device for faster patching.
- ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
- # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
- ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers),
- ):
- assert isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
- text_encoder = cast(CLIPTextModel, text_encoder)
- compel = Compel(
- tokenizer=tokenizer,
- text_encoder=text_encoder,
- textual_inversion_manager=ti_manager,
- dtype_for_device_getter=torch_dtype,
- truncate_long_prompts=False, # TODO:
- returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
- requires_pooled=get_pooled,
- )
+ with text_encoder_info as text_encoder:
+ with (
+ ModelPatcher.apply_ti(tokenizer_model, text_encoder, ti_list) as (
+ tokenizer,
+ ti_manager,
+ ),
+ # Apply the LoRA after text_encoder has been moved to its target device for faster patching.
+ ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
+ # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
+ ModelPatcher.apply_clip_skip(text_encoder, clip_field.skipped_layers),
+ ):
+ assert isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
+ text_encoder = cast(CLIPTextModel, text_encoder)
+ compel = Compel(
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ textual_inversion_manager=ti_manager,
+ dtype_for_device_getter=torch_dtype,
+ truncate_long_prompts=False, # TODO:
+ returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
+ requires_pooled=get_pooled,
+ )
- conjunction = Compel.parse_prompt_string(prompt)
+ conjunction = Compel.parse_prompt_string(prompt)
- if context.config.get().log_tokenization:
- # TODO: better logging for and syntax
- log_tokenization_for_conjunction(conjunction, tokenizer)
+ if context.config.get().log_tokenization:
+ # TODO: better logging for and syntax
+ log_tokenization_for_conjunction(conjunction, tokenizer)
- # TODO: ask for optimizations? to not run text_encoder twice
- c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
- if get_pooled:
- c_pooled = compel.conditioning_provider.get_pooled_embeddings([prompt])
- else:
- c_pooled = None
+ # TODO: ask for optimizations? to not run text_encoder twice
+ c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
+ if get_pooled:
+ c_pooled = compel.conditioning_provider.get_pooled_embeddings([prompt])
+ else:
+ c_pooled = None
- ec = ExtraConditioningInfo(
- tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
- cross_attention_control_args=options.get("cross_attention_control", None),
- )
+ ec = ExtraConditioningInfo(
+ tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
+ cross_attention_control_args=options.get("cross_attention_control", None),
+ )
- del tokenizer
- del text_encoder
- del tokenizer_info
- del text_encoder_info
+ del tokenizer
+ del text_encoder
+ del tokenizer_info
+ del text_encoder_info
- c = c.detach().to("cpu")
- if c_pooled is not None:
- c_pooled = c_pooled.detach().to("cpu")
+ c = c.detach().to("cpu")
+ if c_pooled is not None:
+ c_pooled = c_pooled.detach().to("cpu")
return c, c_pooled, ec
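
Both invocations now follow the same shape: enter the loaded text encoder first, then patch the object the context yields, because that yielded object (not the cached attribute) is what ends up on the execution device. A hedged sketch of that ordering with stand-in types; only the nesting mirrors the code above, and every name here is hypothetical:

```python
# Sketch only: stand-in types showing the outer/inner ordering used above.
# Every name here is hypothetical; only the nesting mirrors the real code.
from contextlib import contextmanager
from typing import Any, Iterator


class LoadedModelSketch:
    def __init__(self, model: Any) -> None:
        self._model = model

    def __enter__(self) -> Any:
        # The real LoadedModel returns self._locker.lock(), i.e. the object
        # that actually sits on the execution device.
        return self._model

    def __exit__(self, *exc: Any) -> None:
        pass


@contextmanager
def apply_patches(model: Any) -> Iterator[Any]:
    """Stand-in for ModelPatcher.apply_ti / apply_lora_* / apply_clip_skip."""
    yield model


text_encoder_info = LoadedModelSketch(object())
with text_encoder_info as text_encoder:   # outer: lock the model onto a device
    with apply_patches(text_encoder):     # inner: patch the yielded object
        pass                              # build the Compel conditioning here
```
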
diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py
index ee579f4bc4..c56fae2c4f 100644
--- a/invokeai/app/services/config/config_default.py
+++ b/invokeai/app/services/config/config_default.py
@@ -23,7 +23,6 @@ INIT_FILE = Path("invokeai.yaml")
DB_FILE = Path("invokeai.db")
LEGACY_INIT_FILE = Path("invokeai.init")
DEFAULT_RAM_CACHE = 10.0
-DEFAULT_VRAM_CACHE = 0.25
DEFAULT_CONVERT_CACHE = 20.0
DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"]
PRECISION = Literal["auto", "float16", "bfloat16", "float32", "autocast"]
@@ -100,9 +99,7 @@ class InvokeAIAppConfig(BaseSettings):
profile_prefix: An optional prefix for profile output files.
profiles_dir: Path to profiles output directory.
ram: Maximum memory amount used by memory model cache for rapid switching (GB).
- vram: Amount of VRAM reserved for model storage (GB).
convert_cache: Maximum size of on-disk converted models cache (GB).
- lazy_offload: Keep models in VRAM until their space is needed.
log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.
Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.
Valid values: `auto`, `float16`, `bfloat16`, `float32`, `autocast`
@@ -168,9 +165,7 @@ class InvokeAIAppConfig(BaseSettings):
# CACHE
ram: float = Field(default_factory=get_default_ram_cache_size, gt=0, description="Maximum memory amount used by memory model cache for rapid switching (GB).")
- vram: float = Field(default=DEFAULT_VRAM_CACHE, ge=0, description="Amount of VRAM reserved for model storage (GB).")
convert_cache: float = Field(default=DEFAULT_CONVERT_CACHE, ge=0, description="Maximum size of on-disk converted models cache (GB).")
- lazy_offload: bool = Field(default=True, description="Keep models in VRAM until their space is needed.")
log_memory_usage: bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.")
# DEVICE
@@ -372,9 +367,6 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
# `max_cache_size` was renamed to `ram` some time in v3, but both names were used
if k == "max_cache_size" and "ram" not in category_dict:
parsed_config_dict["ram"] = v
- # `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
- if k == "max_vram_cache_size" and "vram" not in category_dict:
- parsed_config_dict["vram"] = v
if k == "conf_path":
parsed_config_dict["legacy_models_yaml_path"] = v
if k == "legacy_conf_dir":
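
With `vram` and `lazy_offload` removed, a v3 `max_vram_cache_size` entry has nowhere to map to: the migration still copies `max_cache_size` into `ram` and simply drops the VRAM key. A small illustrative sketch of that behaviour (the input dictionary is hypothetical):

```python
# Illustrative only: how a v3 cache section maps after `vram` and
# `lazy_offload` are removed. The input dictionary is hypothetical.
v3_cache_section = {"max_cache_size": 10.0, "max_vram_cache_size": 0.5}

parsed_config_dict: dict[str, float] = {}
for k, v in v3_cache_section.items():
    # `max_cache_size` still maps to `ram`, as in the code above
    if k == "max_cache_size" and "ram" not in v3_cache_section:
        parsed_config_dict["ram"] = v
    # `max_vram_cache_size` no longer maps to anything and is dropped

assert parsed_config_dict == {"ram": 10.0}
```
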
diff --git a/invokeai/backend/model_manager/load/load_base.py b/invokeai/backend/model_manager/load/load_base.py
index b8ce56eb16..95a681f7d2 100644
--- a/invokeai/backend/model_manager/load/load_base.py
+++ b/invokeai/backend/model_manager/load/load_base.py
@@ -28,8 +28,7 @@ class LoadedModel:
def __enter__(self) -> AnyModel:
"""Context entry."""
- self._locker.lock()
- return self.model
+ return self._locker.lock()
def __exit__(self, *args: Any, **kwargs: Any) -> None:
"""Context exit."""
diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py
index c54a35f15a..1d6a4f15db 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py
@@ -45,8 +45,8 @@ class CacheRecord(Generic[T]):
"""Elements of the cache."""
key: str
- model: T
size: int
+ model: T
loaded: bool = False
_locks: int = 0
@@ -109,28 +109,12 @@ class ModelCacheBase(ABC, Generic[T]):
"""Release a previously-acquired execution device."""
pass
- @property
- @abstractmethod
- def lazy_offloading(self) -> bool:
- """Return true if the cache is configured to lazily offload models in VRAM."""
- pass
-
@property
@abstractmethod
def max_cache_size(self) -> float:
"""Return true if the cache is configured to lazily offload models in VRAM."""
pass
- @abstractmethod
- def offload_unlocked_models(self, size_required: int) -> None:
- """Offload from VRAM any models not actively in use."""
- pass
-
- @abstractmethod
- def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
- """Move model into the indicated device."""
- pass
-
@property
@abstractmethod
def stats(self) -> CacheStats:
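
Declaring `size` ahead of `model` means positional construction of `CacheRecord` would silently swap arguments, which is presumably why the concrete cache switches to keyword arguments below. A sketch of the reordered record, assuming it remains a plain dataclass:

```python
# Sketch of the reordered record; assumed to remain a dataclass.
from dataclasses import dataclass
from typing import Generic, TypeVar

T = TypeVar("T")


@dataclass
class CacheRecordSketch(Generic[T]):
    key: str
    size: int          # now declared before `model`
    model: T
    loaded: bool = False
    _locks: int = 0


# Keyword construction is order-proof, matching the keyword call added in
# model_cache_default.py below. The key string here is purely illustrative.
record = CacheRecordSketch(key="example-model:text_encoder", model=object(), size=246_000_000)
assert record.size == 246_000_000
```
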
diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
index 26185b2fba..82935ef786 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
@@ -19,9 +19,7 @@ context. Use like this:
"""
import gc
-import math
import sys
-import time
from contextlib import suppress
from logging import Logger
from threading import BoundedSemaphore, Lock
@@ -30,7 +28,7 @@ from typing import Dict, List, Optional, Set
import torch
from invokeai.backend.model_manager import AnyModel, SubModelType
-from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
+from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.logging import InvokeAILogger
@@ -44,9 +42,6 @@ if choose_torch_device() == torch.device("mps"):
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
DEFAULT_MAX_CACHE_SIZE = 6.0
-# amount of GPU memory to hold in reserve for use by generations (GB)
-DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75
-
# actual size of a gig
GIG = 1073741824
@@ -60,12 +55,10 @@ class ModelCache(ModelCacheBase[AnyModel]):
def __init__(
self,
max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
- max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
storage_device: torch.device = torch.device("cpu"),
execution_devices: Optional[Set[torch.device]] = None,
precision: torch.dtype = torch.float16,
sequential_offload: bool = False,
- lazy_offloading: bool = True,
sha_chunksize: int = 16777216,
log_memory_usage: bool = False,
logger: Optional[Logger] = None,
@@ -77,18 +70,14 @@ class ModelCache(ModelCacheBase[AnyModel]):
         :param execution_devices: Set of torch devices to load the active model into [calculated]
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
:param precision: Precision for loaded models [torch.float16]
- :param lazy_offloading: Keep model in VRAM until another model needs to be loaded
:param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
:param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
behaviour.
"""
- # allow lazy offloading only when vram cache enabled
- self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0
self._precision: torch.dtype = precision
self._max_cache_size: float = max_cache_size
- self._max_vram_cache_size: float = max_vram_cache_size
self._execution_devices: Set[torch.device] = execution_devices or self._get_execution_devices()
self._storage_device: torch.device = storage_device
self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
@@ -101,7 +90,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
self._lock = Lock()
self._free_execution_device = BoundedSemaphore(len(self._execution_devices))
self._busy_execution_devices: Set[torch.device] = set()
 
self.logger.info(f"Using rendering device(s) {[self._device_name(x) for x in self._execution_devices]}")
@property
@@ -109,11 +98,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
"""Return the logger used by the cache."""
return self._logger
- @property
- def lazy_offloading(self) -> bool:
- """Return true if the cache is configured to lazily offload models in VRAM."""
- return self._lazy_offloading
-
@property
def storage_device(self) -> torch.device:
"""Return the storage device (e.g. "CPU" for RAM)."""
@@ -181,7 +165,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
key = self._make_cache_key(key, submodel_type)
assert key not in self._cached_models
- cache_record = CacheRecord(key, model, size)
+ cache_record = CacheRecord(key=key, model=model, size=size)
self._cached_models[key] = cache_record
self._cache_stack.append(key)
@@ -242,87 +226,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
else:
return model_key
- def offload_unlocked_models(self, size_required: int) -> None:
- """Move any unused models from VRAM."""
- reserved = self._max_vram_cache_size * GIG
- vram_in_use = torch.cuda.memory_allocated() + size_required
- self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB")
- for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
- if vram_in_use <= reserved:
- break
- if not cache_entry.loaded:
- continue
- if not cache_entry.locked:
- self.move_model_to_device(cache_entry, self.storage_device)
- cache_entry.loaded = False
- vram_in_use = torch.cuda.memory_allocated() + size_required
- self.logger.debug(
- f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
- )
-
- torch.cuda.empty_cache()
- if choose_torch_device() == torch.device("mps"):
- mps.empty_cache()
-
- def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
- """Move model into the indicated device.
-
- :param cache_entry: The CacheRecord for the model
- :param target_device: The torch.device to move the model into
-
- May raise a torch.cuda.OutOfMemoryError
- """
- # These attributes are not in the base ModelMixin class but in various derived classes.
- # Some models don't have these attributes, in which case they run in RAM/CPU.
- self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
- if not (hasattr(cache_entry.model, "device") and hasattr(cache_entry.model, "to")):
- return
-
- source_device = cache_entry.model.device
-
- # Note: We compare device types only so that 'cuda' == 'cuda:0'.
- # This would need to be revised to support multi-GPU.
- if torch.device(source_device).type == torch.device(target_device).type:
- return
-
- # may raise an exception here if insufficient GPU VRAM
- self._check_free_vram(target_device, cache_entry.size)
-
- start_model_to_time = time.time()
- snapshot_before = self._capture_memory_snapshot()
- cache_entry.model.to(target_device)
- snapshot_after = self._capture_memory_snapshot()
- end_model_to_time = time.time()
- self.logger.debug(
- f"Moved model '{cache_entry.key}' from {source_device} to"
- f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
- f"Estimated model size: {(cache_entry.size/GIG):.3f} GB."
- f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
- )
-
- if (
- snapshot_before is not None
- and snapshot_after is not None
- and snapshot_before.vram is not None
- and snapshot_after.vram is not None
- ):
- vram_change = abs(snapshot_before.vram - snapshot_after.vram)
-
- # If the estimated model size does not match the change in VRAM, log a warning.
- if not math.isclose(
- vram_change,
- cache_entry.size,
- rel_tol=0.1,
- abs_tol=10 * MB,
- ):
- self.logger.debug(
- f"Moving model '{cache_entry.key}' from {source_device} to"
- f" {target_device} caused an unexpected change in VRAM usage. The model's"
- " estimated size may be incorrect. Estimated model size:"
- f" {(cache_entry.size/GIG):.3f} GB.\n"
- f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
- )
-
def print_cuda_stats(self) -> None:
"""Log CUDA diagnostics."""
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
diff --git a/invokeai/backend/model_manager/load/model_cache/model_locker.py b/invokeai/backend/model_manager/load/model_cache/model_locker.py
index fa4eb1d5be..30c5dfa8c8 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_locker.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_locker.py
@@ -2,6 +2,7 @@
Base class and implementation of a class that moves models in and out of VRAM.
"""
+import copy
from typing import Optional
import torch
@@ -41,15 +42,13 @@ class ModelLocker(ModelLockerBase):
self._cache_entry.lock()
try:
- if self._cache.lazy_offloading:
- self._cache.offload_unlocked_models(self._cache_entry.size)
-
# We wait for a gpu to be free - may raise a TimeoutError
self._execution_device = self._cache.acquire_execution_device(MAX_GPU_WAIT)
- self._cache.move_model_to_device(self._cache_entry, self._execution_device)
- self._cache_entry.loaded = True
-
self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._execution_device}")
+ model_in_gpu = copy.deepcopy(self._cache_entry.model)
+ if hasattr(model_in_gpu, "to"):
+ model_in_gpu.to(self._execution_device)
+ self._cache_entry.loaded = True
self._cache.print_cuda_stats()
except torch.cuda.OutOfMemoryError:
self._cache.logger.warning("Insufficient GPU memory to load model. Aborting")
@@ -58,7 +57,7 @@ class ModelLocker(ModelLockerBase):
except Exception:
self._cache_entry.unlock()
raise
- return self.model
+ return model_in_gpu
def unlock(self) -> None:
"""Call upon exit from context."""
@@ -68,6 +67,10 @@ class ModelLocker(ModelLockerBase):
self._cache_entry.unlock()
if self._execution_device:
self._cache.release_execution_device(self._execution_device)
- if not self._cache.lazy_offloading:
- self._cache.offload_unlocked_models(self._cache_entry.size)
- self._cache.print_cuda_stats()
+
+ try:
+ torch.cuda.empty_cache()
+ torch.mps.empty_cache()
+ except Exception:
+ pass
+ self._cache.print_cuda_stats()
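
Locking no longer migrates the cached model itself: a deep copy is placed on the acquired execution device and returned to the caller, the RAM-resident original stays put, and unlock simply releases the device and empties the allocator caches. A standalone sketch of that behaviour (the helper function is illustrative, not the real `ModelLocker` API):

```python
# Standalone sketch of the new lock behaviour: copy to the execution device,
# hand the copy to the caller, leave the cached CPU model in place.
import copy

import torch


def lock_model_to_device(cached_model: torch.nn.Module, device: torch.device) -> torch.nn.Module:
    model_in_gpu = copy.deepcopy(cached_model)
    if hasattr(model_in_gpu, "to"):
        model_in_gpu.to(device)          # in-place for nn.Module
    return model_in_gpu


cached = torch.nn.Linear(4, 4)           # stays on the storage device (CPU)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
working_copy = lock_model_to_device(cached, device)

assert next(cached.parameters()).device.type == "cpu"            # original untouched
assert next(working_copy.parameters()).device.type == device.type

# On unlock, the copy goes out of scope and the allocator caches are emptied;
# errors are ignored on platforms without CUDA/MPS, as in the diff above.
try:
    torch.cuda.empty_cache()
    torch.mps.empty_cache()
except Exception:
    pass
```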