add locking around thread-critical sections

This commit is contained in:
Lincoln Stein 2024-03-31 16:58:56 -04:00
parent a1dcab9c38
commit 9336a076de

View File

@@ -20,6 +20,7 @@ context. Use like this:
 import gc
 import sys
+import threading
 from contextlib import suppress
 from logging import Logger
 from threading import BoundedSemaphore, Lock
@@ -80,6 +81,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
         self._max_cache_size: float = max_cache_size
         self._execution_devices: Set[torch.device] = execution_devices or self._get_execution_devices()
         self._storage_device: torch.device = storage_device
+        self._lock = threading.Lock()
         self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
         self._log_memory_usage = log_memory_usage
         self._stats: Optional[CacheStats] = None
@@ -162,12 +164,13 @@ class ModelCache(ModelCacheBase[AnyModel]):
         submodel_type: Optional[SubModelType] = None,
     ) -> None:
         """Store model under key and optional submodel_type."""
-        key = self._make_cache_key(key, submodel_type)
-        assert key not in self._cached_models
-        cache_record = CacheRecord(key=key, model=model, size=size)
-        self._cached_models[key] = cache_record
-        self._cache_stack.append(key)
+        with self._lock:
+            key = self._make_cache_key(key, submodel_type)
+            assert key not in self._cached_models
+            cache_record = CacheRecord(key=key, model=model, size=size)
+            self._cached_models[key] = cache_record
+            self._cache_stack.append(key)

     def get(
         self,
@@ -185,36 +188,37 @@ class ModelCache(ModelCacheBase[AnyModel]):
         This may raise an IndexError if the model is not in the cache.
         """
-        key = self._make_cache_key(key, submodel_type)
-        if key in self._cached_models:
-            if self.stats:
-                self.stats.hits += 1
-        else:
-            if self.stats:
-                self.stats.misses += 1
-            raise IndexError(f"The model with key {key} is not in the cache.")
-
-        cache_entry = self._cached_models[key]
-
-        # more stats
-        if self.stats:
-            stats_name = stats_name or key
-            self.stats.cache_size = int(self._max_cache_size * GIG)
-            self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
-            self.stats.in_cache = len(self._cached_models)
-            self.stats.loaded_model_sizes[stats_name] = max(
-                self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size
-            )
-
-        # this moves the entry to the top (right end) of the stack
-        with suppress(Exception):
-            self._cache_stack.remove(key)
-        self._cache_stack.append(key)
-        return ModelLocker(
-            cache=self,
-            cache_entry=cache_entry,
-        )
+        with self._lock:
+            key = self._make_cache_key(key, submodel_type)
+            if key in self._cached_models:
+                if self.stats:
+                    self.stats.hits += 1
+            else:
+                if self.stats:
+                    self.stats.misses += 1
+                raise IndexError(f"The model with key {key} is not in the cache.")
+
+            cache_entry = self._cached_models[key]
+
+            # more stats
+            if self.stats:
+                stats_name = stats_name or key
+                self.stats.cache_size = int(self._max_cache_size * GIG)
+                self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
+                self.stats.in_cache = len(self._cached_models)
+                self.stats.loaded_model_sizes[stats_name] = max(
+                    self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size
+                )
+
+            # this moves the entry to the top (right end) of the stack
+            with suppress(Exception):
+                self._cache_stack.remove(key)
+            self._cache_stack.append(key)
+            return ModelLocker(
+                cache=self,
+                cache_entry=cache_entry,
+            )

     def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
         if self._log_memory_usage:
             return MemorySnapshot.capture()