mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
revert object_serializer_forward_cache.py
This commit is contained in:
parent
f7436f3bae
commit
a84f3058e2
@ -1,4 +1,3 @@
|
|||||||
import threading
|
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
from typing import TYPE_CHECKING, Optional, TypeVar
|
from typing import TYPE_CHECKING, Optional, TypeVar
|
||||||
|
|
||||||
@ -19,8 +18,8 @@ class ObjectSerializerForwardCache(ObjectSerializerBase[T]):
|
|||||||
def __init__(self, underlying_storage: ObjectSerializerBase[T], max_cache_size: int = 20):
|
def __init__(self, underlying_storage: ObjectSerializerBase[T], max_cache_size: int = 20):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._underlying_storage = underlying_storage
|
self._underlying_storage = underlying_storage
|
||||||
self._cache: dict[int, dict[str, T]] = {}
|
self._cache: dict[str, T] = {}
|
||||||
self._cache_ids: dict[int, Queue[str]] = {}
|
self._cache_ids = Queue[str]()
|
||||||
self._max_cache_size = max_cache_size
|
self._max_cache_size = max_cache_size
|
||||||
|
|
||||||
def start(self, invoker: "Invoker") -> None:
|
def start(self, invoker: "Invoker") -> None:
|
||||||
@ -55,27 +54,12 @@ class ObjectSerializerForwardCache(ObjectSerializerBase[T]):
|
|||||||
del self._cache[name]
|
del self._cache[name]
|
||||||
self._on_deleted(name)
|
self._on_deleted(name)
|
||||||
|
|
||||||
def _get_tid_cache(self) -> dict[str, T]:
|
|
||||||
tid = threading.current_thread().ident
|
|
||||||
if tid not in self._cache:
|
|
||||||
self._cache[tid] = {}
|
|
||||||
return self._cache[tid]
|
|
||||||
|
|
||||||
def _get_tid_cache_ids(self) -> Queue[str]:
|
|
||||||
tid = threading.current_thread().ident
|
|
||||||
if tid not in self._cache_ids:
|
|
||||||
self._cache_ids[tid] = Queue[str]()
|
|
||||||
return self._cache_ids[tid]
|
|
||||||
|
|
||||||
def _get_cache(self, name: str) -> Optional[T]:
|
def _get_cache(self, name: str) -> Optional[T]:
|
||||||
cache = self._get_tid_cache()
|
return None if name not in self._cache else self._cache[name]
|
||||||
return None if name not in cache else cache[name]
|
|
||||||
|
|
||||||
def _set_cache(self, name: str, data: T):
|
def _set_cache(self, name: str, data: T):
|
||||||
cache = self._get_tid_cache()
|
if name not in self._cache:
|
||||||
if name not in cache:
|
self._cache[name] = data
|
||||||
cache[name] = data
|
self._cache_ids.put(name)
|
||||||
cache_ids = self._get_tid_cache_ids()
|
if self._cache_ids.qsize() > self._max_cache_size:
|
||||||
cache_ids.put(name)
|
self._cache.pop(self._cache_ids.get())
|
||||||
if cache_ids.qsize() > self._max_cache_size:
|
|
||||||
cache.pop(cache_ids.get())
|
|
||||||
|
@ -156,7 +156,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
|||||||
device = free_device[0]
|
device = free_device[0]
|
||||||
|
|
||||||
# we are outside the lock region now
|
# we are outside the lock region now
|
||||||
self.logger.info("Reserved torch device {device} for execution thread {current_thread}")
|
self.logger.info(f"Reserved torch device {device} for execution thread {current_thread}")
|
||||||
|
|
||||||
# Tell TorchDevice to use this object to get the torch device.
|
# Tell TorchDevice to use this object to get the torch device.
|
||||||
TorchDevice.set_model_cache(self)
|
TorchDevice.set_model_cache(self)
|
||||||
@ -164,7 +164,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
|||||||
yield device
|
yield device
|
||||||
finally:
|
finally:
|
||||||
with self._device_lock:
|
with self._device_lock:
|
||||||
self.logger.info("Released torch device {device}")
|
self.logger.info(f"Released torch device {device}")
|
||||||
self._execution_devices[device] = 0
|
self._execution_devices[device] = 0
|
||||||
self._free_execution_device.release()
|
self._free_execution_device.release()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
Loading…
Reference in New Issue
Block a user