add locking around thread-critical sections

Lincoln Stein 2024-03-31 16:58:56 -04:00
parent a1dcab9c38
commit 9336a076de

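The pattern applied below is a single `threading.Lock` that serializes every read-modify-write of the cache's shared state (the `_cached_models` dict and the `_cache_stack` LRU list), so that no thread can observe one structure updated but not the other. A minimal standalone sketch of the same idea; `TinyLRUCache` and its method names are illustrative stand-ins, not the actual `ModelCache` API:

import threading
from typing import Any, Dict, List


class TinyLRUCache:
    """Illustrative stand-in for the locked cache; not the real ModelCache."""

    def __init__(self) -> None:
        self._items: Dict[str, Any] = {}
        self._stack: List[str] = []  # least-recently-used keys at the left
        self._lock = threading.Lock()

    def put(self, key: str, value: Any) -> None:
        # The lock makes the dict insert and the stack append atomic as a
        # pair; interleaved threads could otherwise corrupt the LRU order.
        with self._lock:
            assert key not in self._items
            self._items[key] = value
            self._stack.append(key)

    def get(self, key: str) -> Any:
        with self._lock:
            value = self._items[key]  # KeyError on a miss (the commit raises IndexError)
            # Move the key to the top (right end) of the stack.
            self._stack.remove(key)
            self._stack.append(key)
            return value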

@@ -20,6 +20,7 @@ context. Use like this:
 import gc
 import sys
+import threading
 from contextlib import suppress
 from logging import Logger
 from threading import BoundedSemaphore, Lock
@@ -80,6 +81,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
         self._max_cache_size: float = max_cache_size
         self._execution_devices: Set[torch.device] = execution_devices or self._get_execution_devices()
         self._storage_device: torch.device = storage_device
+        self._lock = threading.Lock()
         self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
         self._log_memory_usage = log_memory_usage
         self._stats: Optional[CacheStats] = None
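Worth noting about this new field: `threading.Lock` is not re-entrant. `get()` below calls `self.cache_size()` while holding the lock, so helpers reached from inside the locked region must never try to take the same lock, or the calling thread deadlocks; `threading.RLock` would tolerate that. A small runnable demonstration of the difference (standalone, not InvokeAI code):

import threading

# A plain Lock is not re-entrant: a holder that tries to acquire it again
# will block, so acquire() with a timeout is used here to show the failure
# without hanging the demo.
plain = threading.Lock()
plain.acquire()
print(plain.acquire(timeout=0.1))  # False: the same thread cannot re-acquire
plain.release()

# An RLock tracks ownership and allows nested acquisition by its holder.
reentrant = threading.RLock()
with reentrant:
    with reentrant:  # fine: re-entrant acquisition by the owning thread
        print("nested acquisition succeeded")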
@@ -162,12 +164,13 @@ class ModelCache(ModelCacheBase[AnyModel]):
         submodel_type: Optional[SubModelType] = None,
     ) -> None:
         """Store model under key and optional submodel_type."""
-        key = self._make_cache_key(key, submodel_type)
-        assert key not in self._cached_models
-        cache_record = CacheRecord(key=key, model=model, size=size)
-        self._cached_models[key] = cache_record
-        self._cache_stack.append(key)
+        with self._lock:
+            key = self._make_cache_key(key, submodel_type)
+            assert key not in self._cached_models
+            cache_record = CacheRecord(key=key, model=model, size=size)
+            self._cached_models[key] = cache_record
+            self._cache_stack.append(key)
 
     def get(
         self,
@@ -185,36 +188,37 @@ class ModelCache(ModelCacheBase[AnyModel]):
 
         This may raise an IndexError if the model is not in the cache.
         """
-        key = self._make_cache_key(key, submodel_type)
-        if key in self._cached_models:
-            if self.stats:
-                self.stats.hits += 1
-        else:
-            if self.stats:
-                self.stats.misses += 1
-            raise IndexError(f"The model with key {key} is not in the cache.")
+        with self._lock:
+            key = self._make_cache_key(key, submodel_type)
+            if key in self._cached_models:
+                if self.stats:
+                    self.stats.hits += 1
+            else:
+                if self.stats:
+                    self.stats.misses += 1
+                raise IndexError(f"The model with key {key} is not in the cache.")
 
-        cache_entry = self._cached_models[key]
+            cache_entry = self._cached_models[key]
 
-        # more stats
-        if self.stats:
-            stats_name = stats_name or key
-            self.stats.cache_size = int(self._max_cache_size * GIG)
-            self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
-            self.stats.in_cache = len(self._cached_models)
-            self.stats.loaded_model_sizes[stats_name] = max(
-                self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size
-            )
+            # more stats
+            if self.stats:
+                stats_name = stats_name or key
+                self.stats.cache_size = int(self._max_cache_size * GIG)
+                self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
+                self.stats.in_cache = len(self._cached_models)
+                self.stats.loaded_model_sizes[stats_name] = max(
+                    self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size
+                )
 
-        # this moves the entry to the top (right end) of the stack
-        with suppress(Exception):
-            self._cache_stack.remove(key)
-        self._cache_stack.append(key)
-        return ModelLocker(
-            cache=self,
-            cache_entry=cache_entry,
-        )
+            # this moves the entry to the top (right end) of the stack
+            with suppress(Exception):
+                self._cache_stack.remove(key)
+            self._cache_stack.append(key)
+            return ModelLocker(
+                cache=self,
+                cache_entry=cache_entry,
+            )
 
     def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
         if self._log_memory_usage:
             return MemorySnapshot.capture()
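As a quick smoke test of the pattern, the hypothetical `TinyLRUCache` sketched above can be hammered from several threads; with the lock in place, the dict and the LRU stack are guaranteed to stay in sync:

import threading

cache = TinyLRUCache()

def worker(n: int) -> None:
    # Each thread inserts and immediately re-reads its own keys.
    for i in range(1000):
        cache.put(f"model-{n}-{i}", object())
        cache.get(f"model-{n}-{i}")

threads = [threading.Thread(target=worker, args=(n,)) for n in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()

# Every put appended one key and every get re-appended it after removal,
# so both structures must agree on the 4000 keys inserted.
assert len(cache._items) == len(cache._stack) == 4000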