add locking around thread-critical sections

Lincoln Stein 2024-03-31 16:58:56 -04:00
parent a1dcab9c38
commit 9336a076de


@@ -20,6 +20,7 @@ context. Use like this:
 import gc
 import sys
+import threading
 from contextlib import suppress
 from logging import Logger
 from threading import BoundedSemaphore, Lock
@@ -80,6 +81,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
         self._max_cache_size: float = max_cache_size
         self._execution_devices: Set[torch.device] = execution_devices or self._get_execution_devices()
         self._storage_device: torch.device = storage_device
+        self._lock = threading.Lock()
         self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
         self._log_memory_usage = log_memory_usage
         self._stats: Optional[CacheStats] = None
@@ -162,6 +164,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
         submodel_type: Optional[SubModelType] = None,
     ) -> None:
         """Store model under key and optional submodel_type."""
+        with self._lock:
             key = self._make_cache_key(key, submodel_type)
             assert key not in self._cached_models
@@ -185,6 +188,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
         This may raise an IndexError if the model is not in the cache.
         """
+        with self._lock:
             key = self._make_cache_key(key, submodel_type)
             if key in self._cached_models:
                 if self.stats:
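
The change follows a coarse-grained locking pattern: a single threading.Lock is created in __init__, and each method that reads or mutates the shared cache state enters it with a `with` block before doing any work. A minimal, self-contained sketch of that pattern (the class and method bodies below are illustrative only, not the actual ModelCache API) might look like this:

import threading
from typing import Any, Dict, Optional


class LockedModelCache:
    """Minimal stand-in for the pattern in the diff above: every method that
    touches the shared dict takes the same lock, so concurrent put/get calls
    from different threads cannot interleave mid-update."""

    def __init__(self) -> None:
        self._cached_models: Dict[str, Any] = {}
        # A plain (non-reentrant) lock, as in the diff; locked methods must
        # not call each other while holding it, or the thread will deadlock.
        self._lock = threading.Lock()

    def put(self, key: str, model: Any) -> None:
        with self._lock:
            # The check-then-insert sequence is atomic only because both
            # steps happen under the lock.
            if key not in self._cached_models:
                self._cached_models[key] = model

    def get(self, key: str) -> Optional[Any]:
        with self._lock:
            return self._cached_models.get(key)

One consequence of using threading.Lock rather than threading.RLock is that the lock is not reentrant: if one locked method were to call another locked method on the same instance, the calling thread would deadlock, so a design like this has to keep the locked public methods from invoking each other directly.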