Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
fixup unit tests and remove debugging statements
Commit 589a7959c0 (parent e26360f85b)
@@ -4,7 +4,6 @@ from logging import Logger
 
 import torch
 
-import invokeai.backend.util.devices  # horrible hack
 from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk
 from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache
 from invokeai.app.services.shared.sqlite.sqlite_util import init_db
@@ -99,6 +99,7 @@ class CompelInvocation(BaseInvocation):
             textual_inversion_manager=ti_manager,
             dtype_for_device_getter=TorchDevice.choose_torch_dtype,
             truncate_long_prompts=False,
+            device=TorchDevice.choose_torch_device(),
         )
 
         conjunction = Compel.parse_prompt_string(self.prompt)
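The new `device=` argument passes the backend's chosen device to the prompt-encoding step explicitly. As a hedged sketch of what a device/dtype chooser of this kind typically looks like (illustrative stand-ins, not InvokeAI's actual `TorchDevice` helpers):

import torch


def choose_torch_device() -> torch.device:
    """Illustrative stand-in for TorchDevice.choose_torch_device()."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def choose_torch_dtype(device: torch.device) -> torch.dtype:
    """Illustrative stand-in for TorchDevice.choose_torch_dtype()."""
    # Half precision is only worthwhile on accelerators; stay in float32 on CPU.
    return torch.float16 if device.type in ("cuda", "mps") else torch.float32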
@@ -113,6 +114,7 @@ class CompelInvocation(BaseInvocation):
         conditioning_data = ConditioningFieldData(conditionings=[BasicConditioningInfo(embeds=c)])
+
         conditioning_name = context.conditioning.save(conditioning_data)
 
         return ConditioningOutput(
             conditioning=ConditioningField(
                 conditioning_name=conditioning_name,
@@ -74,9 +74,9 @@ class InvocationStatsService(InvocationStatsServiceBase):
             )
             self._stats[graph_execution_state_id].add_node_execution_stats(node_stats)
 
-    def reset_stats(self):
-        self._stats = {}
-        self._cache_stats = {}
+    def reset_stats(self, graph_execution_state_id: str):
+        self._stats.pop(graph_execution_state_id)
+        self._cache_stats.pop(graph_execution_state_id)
 
     def get_stats(self, graph_execution_state_id: str) -> InvocationStatsSummary:
         graph_stats_summary = self._get_graph_summary(graph_execution_state_id)
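`reset_stats` now drops the statistics of one graph execution instead of wiping every graph's data, which matters once several sessions may be in flight at the same time. A minimal sketch of per-graph bookkeeping along these lines (toy class and names, not the real `InvocationStatsService`):

from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class GraphStats:
    node_times: List[float] = field(default_factory=list)


class ToyStatsService:
    def __init__(self) -> None:
        self._stats: Dict[str, GraphStats] = {}

    def add_node_time(self, graph_execution_state_id: str, seconds: float) -> None:
        self._stats.setdefault(graph_execution_state_id, GraphStats()).node_times.append(seconds)

    def reset_stats(self, graph_execution_state_id: str) -> None:
        # Only the finished graph's entry is dropped; other in-flight graphs keep their stats.
        self._stats.pop(graph_execution_state_id)


svc = ToyStatsService()
svc.add_node_time("graph-a", 0.12)
svc.add_node_time("graph-b", 0.34)
svc.reset_stats("graph-a")
print(sorted(svc._stats))  # ['graph-b']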
@@ -76,8 +76,6 @@ class ModelManagerService(ModelManagerServiceBase):
 
         ram_cache = ModelCache(
             max_cache_size=app_config.ram,
-            max_vram_cache_size=app_config.vram,
-            lazy_offloading=app_config.lazy_offload,
             logger=logger,
         )
         convert_cache = ModelConvertCache(cache_path=app_config.convert_cache_path, max_size=app_config.convert_cache)
@@ -1,7 +1,7 @@
 import traceback
 from contextlib import suppress
 from queue import Queue
-from threading import BoundedSemaphore, Thread, Lock
+from threading import BoundedSemaphore, Lock, Thread
 from threading import Event as ThreadEvent
 from typing import Optional, Set
 
@@ -61,7 +61,9 @@ class DefaultSessionRunner(SessionRunnerBase):
         self._on_after_run_session_callbacks = on_after_run_session_callbacks or []
         self._process_lock = Lock()
 
-    def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None) -> None:
+    def start(
+        self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None
+    ) -> None:
         self._services = services
         self._cancel_event = cancel_event
         self._profiler = profiler
@@ -214,7 +216,7 @@ class DefaultSessionRunner(SessionRunnerBase):
         # we don't care about that - suppress the error.
         with suppress(GESStatsNotFoundError):
             self._services.performance_statistics.log_stats(queue_item.session.id)
-            self._services.performance_statistics.reset_stats()
+            self._services.performance_statistics.reset_stats(queue_item.session.id)
 
         for callback in self._on_after_run_session_callbacks:
             callback(queue_item=queue_item)
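The runner calls `reset_stats` inside `with suppress(GESStatsNotFoundError)`, so a session whose stats were already discarded does not fail the whole run. A small self-contained illustration of that pattern (the exception class and helper below are stand-ins, not InvokeAI's implementations):

from contextlib import suppress


class GESStatsNotFoundError(Exception):
    """Stand-in for the real error raised when a graph has no recorded stats."""


def reset_stats(stats: dict, graph_id: str) -> None:
    try:
        stats.pop(graph_id)
    except KeyError:
        raise GESStatsNotFoundError(graph_id) from None


stats = {"graph-a": 1.0}

# Mirrors the runner: reset the finished session's stats, but don't fail the
# whole session if that graph's entry is already gone.
with suppress(GESStatsNotFoundError):
    reset_stats(stats, "graph-b")

print(stats)  # {'graph-a': 1.0}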
@@ -384,7 +386,6 @@ class DefaultSessionProcessor(SessionProcessorBase):
             )
             worker.start()
 
-
     def stop(self, *args, **kwargs) -> None:
         self._stop_event.set()
 
@@ -465,7 +466,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
                 # Run the graph
                 # self.session_runner.run(queue_item=self._queue_item)
 
-            except Exception as e:
+            except Exception:
                 # Wait for next polling interval or event to try again
                 poll_now_event.wait(self._polling_interval)
                 continue
@@ -494,7 +495,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
                 with self._invoker.services.model_manager.load.ram_cache.reserve_execution_device():
                     # Run the session on the reserved GPU
                     self.session_runner.run(queue_item=queue_item)
-            except Exception as e:
+            except Exception:
                 continue
             finally:
                 self._active_queue_items.remove(queue_item)
@@ -239,6 +239,7 @@ class SessionQueueItemWithoutGraph(BaseModel):
     def __hash__(self) -> int:
         return self.item_id
 
+
 class SessionQueueItemDTO(SessionQueueItemWithoutGraph):
     pass
 
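Hashing queue items by their integer `item_id` is what allows set-based bookkeeping such as `self._active_queue_items.remove(queue_item)` in the processor hunks above. A toy equivalent (plain class rather than the real pydantic model):

from typing import Set


class ToyQueueItem:
    """Toy stand-in for SessionQueueItemWithoutGraph (which is a pydantic model)."""

    def __init__(self, item_id: int) -> None:
        self.item_id = item_id

    def __hash__(self) -> int:
        # Hash by the stable integer id, mirroring the real __hash__.
        return self.item_id

    def __eq__(self, other: object) -> bool:
        return isinstance(other, ToyQueueItem) and other.item_id == self.item_id


active_items: Set[ToyQueueItem] = set()
item = ToyQueueItem(item_id=42)
active_items.add(item)
active_items.remove(item)  # set membership relies on __hash__/__eq__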
@@ -325,7 +325,6 @@ class ConditioningInterface(InvocationContextInterface):
         Returns:
             The loaded conditioning data.
         """
-
         return self._services.conditioning.load(name)
 
 
@@ -43,26 +43,9 @@ T = TypeVar("T")
 
 @dataclass
 class CacheRecord(Generic[T]):
-    """
-    Elements of the cache:
-
-    key: Unique key for each model, same as used in the models database.
-    model: Model in memory.
-    state_dict: A read-only copy of the model's state dict in RAM. It will be
-        used as a template for creating a copy in the VRAM.
-    size: Size of the model
-    loaded: True if the model's state dict is currently in VRAM
-
-    Before a model is executed, the state_dict template is copied into VRAM,
-    and then injected into the model. When the model is finished, the VRAM
-    copy of the state dict is deleted, and the RAM version is reinjected
-    into the model.
-    """
+    """Elements of the cache."""
 
     key: str
-    model: T
-    device: torch.device
-    state_dict: Optional[Dict[str, torch.Tensor]]
     size: int
     model: T
     loaded: bool = False
@@ -130,28 +113,12 @@ class ModelCacheBase(ABC, Generic[T]):
         """
         pass
 
-    @property
-    @abstractmethod
-    def lazy_offloading(self) -> bool:
-        """Return true if the cache is configured to lazily offload models in VRAM."""
-        pass
-
     @property
     @abstractmethod
     def max_cache_size(self) -> float:
         """Return true if the cache is configured to lazily offload models in VRAM."""
         pass
 
-    @abstractmethod
-    def offload_unlocked_models(self, size_required: int) -> None:
-        """Offload from VRAM any models not actively in use."""
-        pass
-
-    @abstractmethod
-    def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
-        """Move model into the indicated device."""
-        pass
-
     @property
     @abstractmethod
     def stats(self) -> Optional[CacheStats]:
@@ -19,10 +19,8 @@ context. Use like this:
 """
 
 import gc
-import math
 import sys
 import threading
-import time
 from contextlib import contextmanager, suppress
 from logging import Logger
 from threading import BoundedSemaphore
@@ -31,7 +29,7 @@ from typing import Dict, Generator, List, Optional, Set
 import torch
 
 from invokeai.backend.model_manager import AnyModel, SubModelType
-from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
+from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot
 from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.logging import InvokeAILogger
 
@@ -42,11 +40,6 @@ from .model_locker import ModelLocker
 # Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
 DEFAULT_MAX_CACHE_SIZE = 6.0
 
-# amount of GPU memory to hold in reserve for use by generations (GB)
-# Empirically this value seems to improve performance without starving other
-# processes.
-DEFAULT_MAX_VRAM_CACHE_SIZE = 0.25
-
 # actual size of a gig
 GIG = 1073741824
 
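The surviving constants express the cache budget in gigabytes and convert to bytes with `GIG = 1073741824` (2**30). A quick worked check of that conversion (illustrative helper, not part of the module):

GIG = 1073741824  # 2**30 bytes, the "actual size of a gig" used by the cache


def gb_to_bytes(gigabytes: float) -> int:
    """Convert a cache budget in GB (e.g. DEFAULT_MAX_CACHE_SIZE = 6.0) to bytes."""
    return int(gigabytes * GIG)


assert GIG == 2**30
assert gb_to_bytes(6.0) == 6 * 2**30  # the default 6 GB RAM cache
print(f"{gb_to_bytes(6.0) / GIG:.1f} GB")  # 6.0 GB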
@@ -60,12 +53,10 @@ class ModelCache(ModelCacheBase[AnyModel]):
     def __init__(
         self,
         max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
-        max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
         storage_device: torch.device = torch.device("cpu"),
         execution_devices: Optional[Set[torch.device]] = None,
         precision: torch.dtype = torch.float16,
         sequential_offload: bool = False,
-        lazy_offloading: bool = True,
         sha_chunksize: int = 16777216,
         log_memory_usage: bool = False,
         logger: Optional[Logger] = None,
@@ -76,18 +67,14 @@ class ModelCache(ModelCacheBase[AnyModel]):
         :param max_cache_size: Maximum size of the RAM cache [6.0 GB]
         :param storage_device: Torch device to save inactive model in [torch.device('cpu')]
         :param precision: Precision for loaded models [torch.float16]
-        :param lazy_offloading: Keep model in VRAM until another model needs to be loaded
         :param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
         :param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
             operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
             snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
             behaviour.
         """
-        # allow lazy offloading only when vram cache enabled
-        self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0
         self._precision: torch.dtype = precision
         self._max_cache_size: float = max_cache_size
-        self._max_vram_cache_size: float = max_vram_cache_size
         self._storage_device: torch.device = storage_device
         self._ram_lock = threading.Lock()
         self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
@@ -111,11 +98,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
         """Return the logger used by the cache."""
         return self._logger
 
-    @property
-    def lazy_offloading(self) -> bool:
-        """Return true if the cache is configured to lazily offload models in VRAM."""
-        return self._lazy_offloading
-
     @property
     def storage_device(self) -> torch.device:
         """Return the storage device (e.g. "CPU" for RAM)."""
@@ -233,8 +215,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
         if key in self._cached_models:
             return
         self.make_room(size)
-        state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None
-        cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size)
+        cache_record = CacheRecord(key, model=model, size=size)
         self._cached_models[key] = cache_record
         self._cache_stack.append(key)
 
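With the VRAM state-dict template gone, inserting a model into the cache reduces to making room and recording a slim entry. A stripped-down illustration of that flow (toy cache with integer sizes, not the real `ModelCache`):

from dataclasses import dataclass
from typing import Dict, Generic, List, TypeVar

T = TypeVar("T")


@dataclass
class ToyCacheRecord(Generic[T]):
    key: str
    size: int
    model: T
    loaded: bool = False


class ToyRamCache(Generic[T]):
    def __init__(self, max_size: int) -> None:
        self._max_size = max_size
        self._cached_models: Dict[str, ToyCacheRecord[T]] = {}
        self._cache_stack: List[str] = []  # oldest keys first

    def make_room(self, size: int) -> None:
        # Evict oldest entries until the new model fits (the real cache also
        # checks lock and reference counts before evicting).
        while sum(r.size for r in self._cached_models.values()) + size > self._max_size and self._cache_stack:
            oldest = self._cache_stack.pop(0)
            del self._cached_models[oldest]

    def put(self, key: str, model: T, size: int) -> None:
        if key in self._cached_models:
            return
        self.make_room(size)
        self._cached_models[key] = ToyCacheRecord(key, model=model, size=size)
        self._cache_stack.append(key)


cache: ToyRamCache[str] = ToyRamCache(max_size=10)
cache.put("unet", "fake-weights", size=6)
cache.put("vae", "fake-weights", size=6)  # evicts "unet" to stay under 10
print(list(cache._cached_models))  # ['vae']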
@@ -296,107 +277,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
         else:
             return model_key
 
-    def offload_unlocked_models(self, size_required: int) -> None:
-        """Move any unused models from VRAM."""
-        reserved = self._max_vram_cache_size * GIG
-        vram_in_use = torch.cuda.memory_allocated() + size_required
-        self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB")
-        for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
-            if vram_in_use <= reserved:
-                break
-            if not cache_entry.loaded:
-                continue
-            if not cache_entry.locked:
-                self.move_model_to_device(cache_entry, self.storage_device)
-                cache_entry.loaded = False
-                vram_in_use = torch.cuda.memory_allocated() + size_required
-                self.logger.debug(
-                    f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
-                )
-
-        TorchDevice.empty_cache()
-
-    def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
-        """Move model into the indicated device.
-
-        :param cache_entry: The CacheRecord for the model
-        :param target_device: The torch.device to move the model into
-
-        May raise a torch.cuda.OutOfMemoryError
-        """
-        # These attributes are not in the base ModelMixin class but in various derived classes.
-        # Some models don't have these attributes, in which case they run in RAM/CPU.
-        self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
-        if not (hasattr(cache_entry.model, "device") and hasattr(cache_entry.model, "to")):
-            return
-
-        source_device = cache_entry.device
-
-        # Note: We compare device types only so that 'cuda' == 'cuda:0'.
-        # This would need to be revised to support multi-GPU.
-        if torch.device(source_device).type == torch.device(target_device).type:
-            return
-
-        # This roundabout method for moving the model around is done to avoid
-        # the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
-        # When moving to VRAM, we copy (not move) each element of the state dict from
-        # RAM to a new state dict in VRAM, and then inject it into the model.
-        # This operation is slightly faster than running `to()` on the whole model.
-        #
-        # When the model needs to be removed from VRAM we simply delete the copy
-        # of the state dict in VRAM, and reinject the state dict that is cached
-        # in RAM into the model. So this operation is very fast.
-        start_model_to_time = time.time()
-        snapshot_before = self._capture_memory_snapshot()
-
-        try:
-            if cache_entry.state_dict is not None:
-                assert hasattr(cache_entry.model, "load_state_dict")
-                if target_device == self.storage_device:
-                    cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True)
-                else:
-                    new_dict: Dict[str, torch.Tensor] = {}
-                    for k, v in cache_entry.state_dict.items():
-                        new_dict[k] = v.to(torch.device(target_device), copy=True)
-                    cache_entry.model.load_state_dict(new_dict, assign=True)
-            cache_entry.model.to(target_device)
-            cache_entry.device = target_device
-        except Exception as e:  # blow away cache entry
-            self._delete_cache_entry(cache_entry)
-            raise e
-
-        snapshot_after = self._capture_memory_snapshot()
-        end_model_to_time = time.time()
-        self.logger.debug(
-            f"Moved model '{cache_entry.key}' from {source_device} to"
-            f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
-            f"Estimated model size: {(cache_entry.size/GIG):.3f} GB."
-            f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
-        )
-
-        if (
-            snapshot_before is not None
-            and snapshot_after is not None
-            and snapshot_before.vram is not None
-            and snapshot_after.vram is not None
-        ):
-            vram_change = abs(snapshot_before.vram - snapshot_after.vram)
-
-            # If the estimated model size does not match the change in VRAM, log a warning.
-            if not math.isclose(
-                vram_change,
-                cache_entry.size,
-                rel_tol=0.1,
-                abs_tol=10 * MB,
-            ):
-                self.logger.debug(
-                    f"Moving model '{cache_entry.key}' from {source_device} to"
-                    f" {target_device} caused an unexpected change in VRAM usage. The model's"
-                    " estimated size may be incorrect. Estimated model size:"
-                    f" {(cache_entry.size/GIG):.3f} GB.\n"
-                    f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
-                )
-
     def print_cuda_stats(self) -> None:
         """Log CUDA diagnostics."""
         vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
@@ -440,12 +320,43 @@ class ModelCache(ModelCacheBase[AnyModel]):
         while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
             model_key = self._cache_stack[pos]
             cache_entry = self._cached_models[model_key]
+
+            refs = sys.getrefcount(cache_entry.model)
+
+            # HACK: This is a workaround for a memory-management issue that we haven't tracked down yet. We are directly
+            # going against the advice in the Python docs by using `gc.get_referrers(...)` in this way:
+            # https://docs.python.org/3/library/gc.html#gc.get_referrers
+
+            # manualy clear local variable references of just finished function calls
+            # for some reason python don't want to collect it even by gc.collect() immidiately
+            if refs > 2:
+                while True:
+                    cleared = False
+                    for referrer in gc.get_referrers(cache_entry.model):
+                        if type(referrer).__name__ == "frame":
+                            # RuntimeError: cannot clear an executing frame
+                            with suppress(RuntimeError):
+                                referrer.clear()
+                                cleared = True
+                                # break
+
+                    # repeat if referrers changes(due to frame clear), else exit loop
+                    if cleared:
+                        gc.collect()
+                    else:
+                        break
+
             device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
             self.logger.debug(
-                f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}"
+                f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded},"
+                f" refs: {refs}"
             )
 
-            if not cache_entry.locked:
+            # Expected refs:
+            # 1 from cache_entry
+            # 1 from getrefcount function
+            # 1 from onnx runtime object
+            if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2):
                 self.logger.debug(
                     f"Removing {model_key} from RAM cache to free at least {(size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
                 )
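The new guard only evicts a model once nothing outside the cache still holds it, and the frame-clearing loop works around locals of already-finished calls pinning models in memory. The snippet below is a standalone demonstration of the same `sys.getrefcount` / `gc.get_referrers` pattern (a toy demo, not InvokeAI code; exact counts vary by Python version):

import gc
import sys
from contextlib import suppress


class ToyModel:
    """Stand-in for a large model object held by a cache."""


def demo() -> None:
    model = ToyModel()
    baseline = sys.getrefcount(model)  # the local reference plus getrefcount's own temporary

    def use_model(m: ToyModel) -> None:
        local_alias = m  # noqa: F841  (deliberately held by this frame)
        raise RuntimeError("simulated failure")

    try:
        use_model(model)
    except RuntimeError as exc:
        saved_tb = exc.__traceback__  # the traceback keeps use_model's frame (and its locals) alive

    print("extra refs while the frame is alive:", sys.getrefcount(model) - baseline)

    # Mirror of the make_room() hack: clear the locals of frames that still refer to the model.
    for referrer in gc.get_referrers(model):
        if type(referrer).__name__ == "frame":
            with suppress(RuntimeError):  # cannot clear the currently executing frame
                referrer.clear()
    gc.collect()

    print("extra refs after clearing frames:", sys.getrefcount(model) - baseline)
    del saved_tb


if __name__ == "__main__":
    demo()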
@@ -2,6 +2,7 @@
 Base class and implementation of a class that moves models in and out of VRAM.
 """
 
+import copy
 from typing import Optional
 
 import torch
@@ -54,14 +55,13 @@ class ModelLocker(ModelLockerBase):
         # NOTE that the model has to have the to() method in order for this code to move it into GPU!
         self._cache_entry.lock()
         try:
-            if self._cache.lazy_offloading:
-                self._cache.offload_unlocked_models(self._cache_entry.size)
-
-            execution_device = self._cache.get_execution_device()
-            self._cache.move_model_to_device(self._cache_entry, execution_device)
+            # We wait for a gpu to be free - may raise a ValueError
+            self._execution_device = self._cache.get_execution_device()
+            self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._execution_device}")
+            model_in_gpu = copy.deepcopy(self._cache_entry.model)
+            if hasattr(model_in_gpu, "to"):
+                model_in_gpu.to(self._execution_device)
             self._cache_entry.loaded = True
-
-            self._cache.logger.debug(f"Locking {self._cache_entry.key} in {execution_device}")
             self._cache.print_cuda_stats()
         except torch.cuda.OutOfMemoryError:
             self._cache.logger.warning("Insufficient GPU memory to load model. Aborting")
@@ -70,15 +70,11 @@ class ModelLocker(ModelLockerBase):
         except Exception:
             self._cache_entry.unlock()
             raise
-        return self.model
+        return model_in_gpu
 
     def unlock(self) -> None:
         """Call upon exit from context."""
         if not hasattr(self.model, "to"):
             return
 
         self._cache_entry.unlock()
-        if not self._cache.lazy_offloading:
-            self._cache.offload_unlocked_models(0)
         self._cache.print_cuda_stats()
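Together with the removal of `move_model_to_device`, the locker now leaves the cached copy on the storage device and hands the caller a deep copy moved to the execution device; unlocking no longer offloads anything. A rough sketch of that strategy (toy class, not InvokeAI's `ModelLocker` API):

import copy

import torch


class ToyLocker:
    """Lock a cached model by copying it to the execution device; the CPU copy stays untouched."""

    def __init__(self, cached_model: torch.nn.Module, execution_device: torch.device) -> None:
        self._cached_model = cached_model  # lives on the storage device (CPU)
        self._execution_device = execution_device

    def lock(self) -> torch.nn.Module:
        model_in_gpu = copy.deepcopy(self._cached_model)
        if hasattr(model_in_gpu, "to"):
            model_in_gpu.to(self._execution_device)
        return model_in_gpu  # caller runs inference on the copy

    def unlock(self) -> None:
        # Nothing to move back: dropping the copy frees the execution device.
        pass


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    locker = ToyLocker(torch.nn.Linear(8, 8), device)
    work_copy = locker.lock()
    print(next(work_copy.parameters()).device)  # cuda:0 if available, else cpu
    locker.unlock()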
@@ -54,6 +54,7 @@ def mock_services() -> InvocationServices:
         workflow_records=None,  # type: ignore
         tensors=None,  # type: ignore
         conditioning=None,  # type: ignore
+        performance_statistics=None,  # type: ignore
     )
 
 