mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
revert object_serializer_forward_cache.py
This commit is contained in:
parent f7436f3bae
commit a84f3058e2
object_serializer_forward_cache.py

@@ -1,4 +1,3 @@
-import threading
 from queue import Queue
 from typing import TYPE_CHECKING, Optional, TypeVar
 
@@ -19,8 +18,8 @@ class ObjectSerializerForwardCache(ObjectSerializerBase[T]):
     def __init__(self, underlying_storage: ObjectSerializerBase[T], max_cache_size: int = 20):
         super().__init__()
         self._underlying_storage = underlying_storage
-        self._cache: dict[int, dict[str, T]] = {}
-        self._cache_ids: dict[int, Queue[str]] = {}
+        self._cache: dict[str, T] = {}
+        self._cache_ids = Queue[str]()
         self._max_cache_size = max_cache_size
 
     def start(self, invoker: "Invoker") -> None:
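For context on the shape change above: the reverted design kept one inner cache per thread, keyed by threading.current_thread().ident through the _get_tid_cache helpers that the next hunk removes. A minimal, self-contained sketch of that per-thread pattern; the module-level names here are illustrative, not InvokeAI's API:

import threading

# Illustrative stand-in for the dict[int, dict[str, T]] layout being
# reverted: an outer dict keyed by thread ident, one inner cache per thread.
_caches: dict[int, dict[str, str]] = {}

def _tid_cache() -> dict[str, str]:
    tid = threading.current_thread().ident
    assert tid is not None  # ident is assigned once a thread has started
    if tid not in _caches:
        _caches[tid] = {}
    return _caches[tid]

def worker() -> None:
    _tid_cache()["greeting"] = f"hello from {threading.current_thread().name}"

threads = [threading.Thread(target=worker, name=f"t{i}") for i in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(_caches)  # typically two entries, one inner cache per worker thread

Each thread sees an isolated cache, at the cost of potentially holding a copy of the same entry per thread; the revert returns to a single dict shared by all threads.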
@@ -55,27 +54,12 @@ class ObjectSerializerForwardCache(ObjectSerializerBase[T]):
             del self._cache[name]
         self._on_deleted(name)
 
-    def _get_tid_cache(self) -> dict[str, T]:
-        tid = threading.current_thread().ident
-        if tid not in self._cache:
-            self._cache[tid] = {}
-        return self._cache[tid]
-
-    def _get_tid_cache_ids(self) -> Queue[str]:
-        tid = threading.current_thread().ident
-        if tid not in self._cache_ids:
-            self._cache_ids[tid] = Queue[str]()
-        return self._cache_ids[tid]
-
     def _get_cache(self, name: str) -> Optional[T]:
-        cache = self._get_tid_cache()
-        return None if name not in cache else cache[name]
+        return None if name not in self._cache else self._cache[name]
 
     def _set_cache(self, name: str, data: T):
-        cache = self._get_tid_cache()
-        if name not in cache:
-            cache[name] = data
-            cache_ids = self._get_tid_cache_ids()
-            cache_ids.put(name)
-            if cache_ids.qsize() > self._max_cache_size:
-                cache.pop(cache_ids.get())
+        if name not in self._cache:
+            self._cache[name] = data
+            self._cache_ids.put(name)
+            if self._cache_ids.qsize() > self._max_cache_size:
+                self._cache.pop(self._cache_ids.get())
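The restored _set_cache bounds the cache with first-in-first-out eviction: insertion order is recorded in a Queue, and once more than max_cache_size names have been inserted, the oldest is popped. A self-contained sketch of the same pattern; FIFOCache is a hypothetical name, not part of InvokeAI:

from queue import Queue
from typing import Generic, Optional, TypeVar

T = TypeVar("T")

class FIFOCache(Generic[T]):
    """Bounded cache with FIFO eviction, mirroring the restored
    ObjectSerializerForwardCache logic."""

    def __init__(self, max_size: int = 20):
        self._cache: dict[str, T] = {}
        self._ids: Queue[str] = Queue()
        self._max_size = max_size

    def get(self, name: str) -> Optional[T]:
        return self._cache.get(name)

    def set(self, name: str, data: T) -> None:
        if name not in self._cache:
            self._cache[name] = data
            self._ids.put(name)
            if self._ids.qsize() > self._max_size:
                # Evict the oldest key. This is FIFO, not LRU:
                # reads do not refresh an entry's position.
                self._cache.pop(self._ids.get())

cache = FIFOCache[str](max_size=2)
cache.set("a", "1")
cache.set("b", "2")
cache.set("c", "3")
assert cache.get("a") is None  # "a" was inserted first, so it was evicted
assert cache.get("c") == "3"

The remaining two hunks in this commit touch ModelCache and fix a separate logging bug.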
@@ -156,7 +156,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
                 device = free_device[0]
 
         # we are outside the lock region now
-        self.logger.info("Reserved torch device {device} for execution thread {current_thread}")
+        self.logger.info(f"Reserved torch device {device} for execution thread {current_thread}")
 
         # Tell TorchDevice to use this object to get the torch device.
         TorchDevice.set_model_cache(self)
@@ -164,7 +164,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
             yield device
         finally:
             with self._device_lock:
-                self.logger.info("Released torch device {device}")
+                self.logger.info(f"Released torch device {device}")
                 self._execution_devices[device] = 0
                 self._free_execution_device.release()
                 torch.cuda.empty_cache()
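Both ModelCache hunks fix the same bug: without the f prefix on the string literal, Python logs the braces verbatim instead of interpolating the variables. A quick illustration:

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
device = "cuda:0"

logger.info("Released torch device {device}")
# INFO:__main__:Released torch device {device}   <- literal braces

logger.info(f"Released torch device {device}")
# INFO:__main__:Released torch device cuda:0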