mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add locking around thread-critical sections
This commit is contained in:
parent
a1dcab9c38
commit
9336a076de
@ -20,6 +20,7 @@ context. Use like this:
|
|||||||
|
|
||||||
import gc
|
import gc
|
||||||
import sys
|
import sys
|
||||||
|
import threading
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
from threading import BoundedSemaphore, Lock
|
from threading import BoundedSemaphore, Lock
|
||||||
@ -80,6 +81,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
|||||||
self._max_cache_size: float = max_cache_size
|
self._max_cache_size: float = max_cache_size
|
||||||
self._execution_devices: Set[torch.device] = execution_devices or self._get_execution_devices()
|
self._execution_devices: Set[torch.device] = execution_devices or self._get_execution_devices()
|
||||||
self._storage_device: torch.device = storage_device
|
self._storage_device: torch.device = storage_device
|
||||||
|
self._lock = threading.Lock()
|
||||||
self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
|
self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
|
||||||
self._log_memory_usage = log_memory_usage
|
self._log_memory_usage = log_memory_usage
|
||||||
self._stats: Optional[CacheStats] = None
|
self._stats: Optional[CacheStats] = None
|
||||||
@ -162,6 +164,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
|||||||
submodel_type: Optional[SubModelType] = None,
|
submodel_type: Optional[SubModelType] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Store model under key and optional submodel_type."""
|
"""Store model under key and optional submodel_type."""
|
||||||
|
with self._lock:
|
||||||
key = self._make_cache_key(key, submodel_type)
|
key = self._make_cache_key(key, submodel_type)
|
||||||
assert key not in self._cached_models
|
assert key not in self._cached_models
|
||||||
|
|
||||||
@ -185,6 +188,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
|||||||
|
|
||||||
This may raise an IndexError if the model is not in the cache.
|
This may raise an IndexError if the model is not in the cache.
|
||||||
"""
|
"""
|
||||||
|
with self._lock:
|
||||||
key = self._make_cache_key(key, submodel_type)
|
key = self._make_cache_key(key, submodel_type)
|
||||||
if key in self._cached_models:
|
if key in self._cached_models:
|
||||||
if self.stats:
|
if self.stats:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user