From 9336a076deb2b9632f3b29c478e10ea6922e652b Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Sun, 31 Mar 2024 16:58:56 -0400
Subject: [PATCH] add locking around thread-critical sections

---
 .../load/model_cache/model_cache_default.py   | 66 ++++++++++---------
 1 file changed, 35 insertions(+), 31 deletions(-)

diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
index 82935ef786..90090b522d 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
@@ -20,6 +20,7 @@ context. Use like this:
 
 import gc
 import sys
+import threading
 from contextlib import suppress
 from logging import Logger
 from threading import BoundedSemaphore, Lock
@@ -80,6 +81,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
         self._max_cache_size: float = max_cache_size
         self._execution_devices: Set[torch.device] = execution_devices or self._get_execution_devices()
         self._storage_device: torch.device = storage_device
+        self._lock = threading.Lock()
         self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
         self._log_memory_usage = log_memory_usage
         self._stats: Optional[CacheStats] = None
@@ -162,12 +164,13 @@ class ModelCache(ModelCacheBase[AnyModel]):
         submodel_type: Optional[SubModelType] = None,
     ) -> None:
         """Store model under key and optional submodel_type."""
-        key = self._make_cache_key(key, submodel_type)
-        assert key not in self._cached_models
+        with self._lock:
+            key = self._make_cache_key(key, submodel_type)
+            assert key not in self._cached_models
 
-        cache_record = CacheRecord(key=key, model=model, size=size)
-        self._cached_models[key] = cache_record
-        self._cache_stack.append(key)
+            cache_record = CacheRecord(key=key, model=model, size=size)
+            self._cached_models[key] = cache_record
+            self._cache_stack.append(key)
 
     def get(
         self,
@@ -185,36 +188,37 @@ class ModelCache(ModelCacheBase[AnyModel]):
 
         This may raise an IndexError if the model is not in the cache.
         """
-        key = self._make_cache_key(key, submodel_type)
-        if key in self._cached_models:
-            if self.stats:
-                self.stats.hits += 1
-        else:
-            if self.stats:
-                self.stats.misses += 1
-            raise IndexError(f"The model with key {key} is not in the cache.")
+        with self._lock:
+            key = self._make_cache_key(key, submodel_type)
+            if key in self._cached_models:
+                if self.stats:
+                    self.stats.hits += 1
+            else:
+                if self.stats:
+                    self.stats.misses += 1
+                raise IndexError(f"The model with key {key} is not in the cache.")
 
-        cache_entry = self._cached_models[key]
+            cache_entry = self._cached_models[key]
 
-        # more stats
-        if self.stats:
-            stats_name = stats_name or key
-            self.stats.cache_size = int(self._max_cache_size * GIG)
-            self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
-            self.stats.in_cache = len(self._cached_models)
-            self.stats.loaded_model_sizes[stats_name] = max(
-                self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size
+            # more stats
+            if self.stats:
+                stats_name = stats_name or key
+                self.stats.cache_size = int(self._max_cache_size * GIG)
+                self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
+                self.stats.in_cache = len(self._cached_models)
+                self.stats.loaded_model_sizes[stats_name] = max(
+                    self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size
+                )
+
+            # this moves the entry to the top (right end) of the stack
+            with suppress(Exception):
+                self._cache_stack.remove(key)
+            self._cache_stack.append(key)
+            return ModelLocker(
+                cache=self,
+                cache_entry=cache_entry,
             )
 
-        # this moves the entry to the top (right end) of the stack
-        with suppress(Exception):
-            self._cache_stack.remove(key)
-        self._cache_stack.append(key)
-        return ModelLocker(
-            cache=self,
-            cache_entry=cache_entry,
-        )
-
     def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
         if self._log_memory_usage:
             return MemorySnapshot.capture()