add locking around thread-critical sections

This commit is contained in:
Lincoln Stein 2024-03-31 16:58:56 -04:00
parent a1dcab9c38
commit 9336a076de

View File

@ -20,6 +20,7 @@ context. Use like this:
import gc
import sys
import threading
from contextlib import suppress
from logging import Logger
from threading import BoundedSemaphore, Lock
@ -80,6 +81,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
self._max_cache_size: float = max_cache_size
self._execution_devices: Set[torch.device] = execution_devices or self._get_execution_devices()
self._storage_device: torch.device = storage_device
self._lock = threading.Lock()
self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
self._log_memory_usage = log_memory_usage
self._stats: Optional[CacheStats] = None
@ -162,6 +164,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
submodel_type: Optional[SubModelType] = None,
) -> None:
"""Store model under key and optional submodel_type."""
with self._lock:
key = self._make_cache_key(key, submodel_type)
assert key not in self._cached_models
@ -185,6 +188,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
This may raise an IndexError if the model is not in the cache.
"""
with self._lock:
key = self._make_cache_key(key, submodel_type)
if key in self._cached_models:
if self.stats: