Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Commit 9336a076de (parent a1dcab9c38): add locking around thread-critical sections
@@ -20,6 +20,7 @@ context. Use like this:
 import gc
 import sys
+import threading
 from contextlib import suppress
 from logging import Logger
 from threading import BoundedSemaphore, Lock
@@ -80,6 +81,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
         self._max_cache_size: float = max_cache_size
         self._execution_devices: Set[torch.device] = execution_devices or self._get_execution_devices()
         self._storage_device: torch.device = storage_device
+        self._lock = threading.Lock()
         self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
         self._log_memory_usage = log_memory_usage
         self._stats: Optional[CacheStats] = None
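The new self._lock gives each ModelCache instance its own mutex, and the later hunks acquire it in every method that touches the shared bookkeeping (_cached_models, _cache_stack, _stats). Below is a minimal sketch of that pattern using made-up names (ToyCache, _store, _order), not InvokeAI's actual classes:

# Sketch of the pattern the commit adopts: one Lock per cache instance,
# acquired by every method that reads or mutates shared state.
# ToyCache, _store and _order are illustrative names, not InvokeAI code.
import threading


class ToyCache:
    def __init__(self) -> None:
        self._store: dict[str, object] = {}
        self._order: list[str] = []       # LRU order, oldest first
        self._lock = threading.Lock()     # created once, in __init__

    def put(self, key: str, value: object) -> None:
        with self._lock:                  # serialize writers
            assert key not in self._store
            self._store[key] = value
            self._order.append(key)

    def get(self, key: str) -> object:
        with self._lock:                  # readers also mutate the LRU order
            value = self._store[key]
            self._order.remove(key)
            self._order.append(key)
            return value

A single lock shared by readers and writers is the simplest correct choice here, since get() also mutates state (the LRU order and the stats counters).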
@@ -162,12 +164,13 @@ class ModelCache(ModelCacheBase[AnyModel]):
         submodel_type: Optional[SubModelType] = None,
     ) -> None:
         """Store model under key and optional submodel_type."""
-        key = self._make_cache_key(key, submodel_type)
-        assert key not in self._cached_models
+        with self._lock:
+            key = self._make_cache_key(key, submodel_type)
+            assert key not in self._cached_models
 
-        cache_record = CacheRecord(key=key, model=model, size=size)
-        self._cached_models[key] = cache_record
-        self._cache_stack.append(key)
+            cache_record = CacheRecord(key=key, model=model, size=size)
+            self._cached_models[key] = cache_record
+            self._cache_stack.append(key)
 
     def get(
         self,
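Wrapping put() in the lock closes a classic check-then-act race: without it, two threads can both pass the "key is not yet cached" check before either inserts, leaving the key in the LRU stack twice. A small self-contained illustration of that race, with invented names rather than InvokeAI code:

# Toy demonstration (not InvokeAI's code) of the race the lock closes in put():
# both threads pass the membership check before either inserts, so the key is
# stored once but appended to the LRU stack twice.
import threading
import time

cached_models: dict[str, str] = {}
cache_stack: list[str] = []


def unlocked_put(key: str, value: str) -> None:
    if key in cached_models:          # stand-in for `assert key not in ...`
        return
    time.sleep(0.01)                  # widen the race window for the demo
    cached_models[key] = value
    cache_stack.append(key)           # duplicated if two threads race here


threads = [threading.Thread(target=unlocked_put, args=("vae", f"copy-{i}")) for i in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(cache_stack)  # typically ['vae', 'vae'] without a lock around check + insert

Holding one lock around the whole check-and-insert sequence, as the commit does, makes the pair atomic and preserves the invariant that each key appears once.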
@@ -185,36 +188,37 @@ class ModelCache(ModelCacheBase[AnyModel]):
 
         This may raise an IndexError if the model is not in the cache.
         """
-        key = self._make_cache_key(key, submodel_type)
-        if key in self._cached_models:
-            if self.stats:
-                self.stats.hits += 1
-        else:
-            if self.stats:
-                self.stats.misses += 1
-            raise IndexError(f"The model with key {key} is not in the cache.")
+        with self._lock:
+            key = self._make_cache_key(key, submodel_type)
+            if key in self._cached_models:
+                if self.stats:
+                    self.stats.hits += 1
+            else:
+                if self.stats:
+                    self.stats.misses += 1
+                raise IndexError(f"The model with key {key} is not in the cache.")
 
-        cache_entry = self._cached_models[key]
+            cache_entry = self._cached_models[key]
 
-        # more stats
-        if self.stats:
-            stats_name = stats_name or key
-            self.stats.cache_size = int(self._max_cache_size * GIG)
-            self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
-            self.stats.in_cache = len(self._cached_models)
-            self.stats.loaded_model_sizes[stats_name] = max(
-                self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size
-            )
+            # more stats
+            if self.stats:
+                stats_name = stats_name or key
+                self.stats.cache_size = int(self._max_cache_size * GIG)
+                self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
+                self.stats.in_cache = len(self._cached_models)
+                self.stats.loaded_model_sizes[stats_name] = max(
+                    self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size
+                )
 
-        # this moves the entry to the top (right end) of the stack
-        with suppress(Exception):
-            self._cache_stack.remove(key)
-        self._cache_stack.append(key)
-        return ModelLocker(
-            cache=self,
-            cache_entry=cache_entry,
-        )
+            # this moves the entry to the top (right end) of the stack
+            with suppress(Exception):
+                self._cache_stack.remove(key)
+            self._cache_stack.append(key)
+            return ModelLocker(
+                cache=self,
+                cache_entry=cache_entry,
+            )
 
     def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
         if self._log_memory_usage:
             return MemorySnapshot.capture()
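get() needs the same lock for more than the stats counters: the LRU "touch" at the end (remove the key from the stack, then append it) is a two-step mutation of a shared list, and interleaving it across threads can drop or duplicate entries. A sketch of that touch operation under a lock, using a ThreadPoolExecutor and throwaway key names rather than the real ModelCache:

# Sketch (not InvokeAI code) of the locked LRU "touch" in get().
import threading
from concurrent.futures import ThreadPoolExecutor

lock = threading.Lock()
cache_stack = ["unet", "vae", "clip"]


def touch(key: str) -> None:
    # Move key to the hot (most recently used) end; the lock makes the
    # remove + append pair atomic, mirroring the commit's `with self._lock:`.
    with lock:
        if key in cache_stack:
            cache_stack.remove(key)
            cache_stack.append(key)


with ThreadPoolExecutor(max_workers=4) as pool:
    for _ in range(100):
        pool.map(touch, ["unet", "vae", "clip"])

# With the lock, each key still appears exactly once afterwards.
assert sorted(cache_stack) == ["clip", "unet", "vae"]
print(cache_stack)

Without the lock, one thread's remove can interleave with another's append on the same key, leaving duplicate stack entries or raising mid-update; serializing the whole get() body, as the commit does, rules that out.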