revert object_serializer_forward_cache.py

This commit is contained in:
Lincoln Stein 2024-04-15 22:28:48 -04:00
parent f7436f3bae
commit a84f3058e2
2 changed files with 10 additions and 26 deletions

View File

@@ -1,4 +1,3 @@
import threading
from queue import Queue from queue import Queue
from typing import TYPE_CHECKING, Optional, TypeVar from typing import TYPE_CHECKING, Optional, TypeVar
@@ -19,8 +18,8 @@ class ObjectSerializerForwardCache(ObjectSerializerBase[T]):
def __init__(self, underlying_storage: ObjectSerializerBase[T], max_cache_size: int = 20): def __init__(self, underlying_storage: ObjectSerializerBase[T], max_cache_size: int = 20):
super().__init__() super().__init__()
self._underlying_storage = underlying_storage self._underlying_storage = underlying_storage
self._cache: dict[int, dict[str, T]] = {} self._cache: dict[str, T] = {}
self._cache_ids: dict[int, Queue[str]] = {} self._cache_ids = Queue[str]()
self._max_cache_size = max_cache_size self._max_cache_size = max_cache_size
def start(self, invoker: "Invoker") -> None: def start(self, invoker: "Invoker") -> None:
@@ -55,27 +54,12 @@ class ObjectSerializerForwardCache(ObjectSerializerBase[T]):
del self._cache[name] del self._cache[name]
self._on_deleted(name) self._on_deleted(name)
def _get_tid_cache(self) -> dict[str, T]:
tid = threading.current_thread().ident
if tid not in self._cache:
self._cache[tid] = {}
return self._cache[tid]
def _get_tid_cache_ids(self) -> Queue[str]:
tid = threading.current_thread().ident
if tid not in self._cache_ids:
self._cache_ids[tid] = Queue[str]()
return self._cache_ids[tid]
def _get_cache(self, name: str) -> Optional[T]: def _get_cache(self, name: str) -> Optional[T]:
cache = self._get_tid_cache() return None if name not in self._cache else self._cache[name]
return None if name not in cache else cache[name]
def _set_cache(self, name: str, data: T): def _set_cache(self, name: str, data: T):
cache = self._get_tid_cache() if name not in self._cache:
if name not in cache: self._cache[name] = data
cache[name] = data self._cache_ids.put(name)
cache_ids = self._get_tid_cache_ids() if self._cache_ids.qsize() > self._max_cache_size:
cache_ids.put(name) self._cache.pop(self._cache_ids.get())
if cache_ids.qsize() > self._max_cache_size:
cache.pop(cache_ids.get())

View File

@@ -156,7 +156,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
device = free_device[0] device = free_device[0]
# we are outside the lock region now # we are outside the lock region now
self.logger.info("Reserved torch device {device} for execution thread {current_thread}") self.logger.info(f"Reserved torch device {device} for execution thread {current_thread}")
# Tell TorchDevice to use this object to get the torch device. # Tell TorchDevice to use this object to get the torch device.
TorchDevice.set_model_cache(self) TorchDevice.set_model_cache(self)
@@ -164,7 +164,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
yield device yield device
finally: finally:
with self._device_lock: with self._device_lock:
self.logger.info("Released torch device {device}") self.logger.info(f"Released torch device {device}")
self._execution_devices[device] = 0 self._execution_devices[device] = 0
self._free_execution_device.release() self._free_execution_device.release()
torch.cuda.empty_cache() torch.cuda.empty_cache()