From a84f3058e2119b7b54ea16ae6730ada655fb0427 Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Mon, 15 Apr 2024 22:28:48 -0400
Subject: [PATCH] revert object_serializer_forward_cache.py

---
 .../object_serializer_forward_cache.py        | 32 +++++--------
 .../load/model_cache/model_cache_default.py   |  4 +--
 2 files changed, 10 insertions(+), 26 deletions(-)

diff --git a/invokeai/app/services/object_serializer/object_serializer_forward_cache.py b/invokeai/app/services/object_serializer/object_serializer_forward_cache.py
index bf16bfe242..b361259a4b 100644
--- a/invokeai/app/services/object_serializer/object_serializer_forward_cache.py
+++ b/invokeai/app/services/object_serializer/object_serializer_forward_cache.py
@@ -1,4 +1,3 @@
-import threading
 from queue import Queue
 from typing import TYPE_CHECKING, Optional, TypeVar
 
@@ -19,8 +18,8 @@ class ObjectSerializerForwardCache(ObjectSerializerBase[T]):
     def __init__(self, underlying_storage: ObjectSerializerBase[T], max_cache_size: int = 20):
         super().__init__()
         self._underlying_storage = underlying_storage
-        self._cache: dict[int, dict[str, T]] = {}
-        self._cache_ids: dict[int, Queue[str]] = {}
+        self._cache: dict[str, T] = {}
+        self._cache_ids = Queue[str]()
         self._max_cache_size = max_cache_size
 
     def start(self, invoker: "Invoker") -> None:
@@ -55,27 +54,12 @@ class ObjectSerializerForwardCache(ObjectSerializerBase[T]):
             del self._cache[name]
         self._on_deleted(name)
 
-    def _get_tid_cache(self) -> dict[str, T]:
-        tid = threading.current_thread().ident
-        if tid not in self._cache:
-            self._cache[tid] = {}
-        return self._cache[tid]
-
-    def _get_tid_cache_ids(self) -> Queue[str]:
-        tid = threading.current_thread().ident
-        if tid not in self._cache_ids:
-            self._cache_ids[tid] = Queue[str]()
-        return self._cache_ids[tid]
-
     def _get_cache(self, name: str) -> Optional[T]:
-        cache = self._get_tid_cache()
-        return None if name not in cache else cache[name]
+        return None if name not in self._cache else self._cache[name]
 
     def _set_cache(self, name: str, data: T):
-        cache = self._get_tid_cache()
-        if name not in cache:
-            cache[name] = data
-            cache_ids = self._get_tid_cache_ids()
-            cache_ids.put(name)
-            if cache_ids.qsize() > self._max_cache_size:
-                cache.pop(cache_ids.get())
+        if name not in self._cache:
+            self._cache[name] = data
+            self._cache_ids.put(name)
+            if self._cache_ids.qsize() > self._max_cache_size:
+                self._cache.pop(self._cache_ids.get())
diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
index f7f466f2b0..026bb8aec5 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
@@ -156,7 +156,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
             device = free_device[0]
 
         # we are outside the lock region now
-        self.logger.info("Reserved torch device {device} for execution thread {current_thread}")
+        self.logger.info(f"Reserved torch device {device} for execution thread {current_thread}")
 
         # Tell TorchDevice to use this object to get the torch device.
         TorchDevice.set_model_cache(self)
@@ -164,7 +164,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
             yield device
         finally:
             with self._device_lock:
-                self.logger.info("Released torch device {device}")
+                self.logger.info(f"Released torch device {device}")
                 self._execution_devices[device] = 0
                 self._free_execution_device.release()
                 torch.cuda.empty_cache()
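
Note for reviewers: this revert returns ObjectSerializerForwardCache to a single
process-wide dict paired with a FIFO queue of names; once the queue grows past
max_cache_size, the oldest name is dequeued and its entry evicted. Below is a
minimal standalone sketch of that pattern for reference. The class name
BoundedForwardCache and its get/set methods are illustrative only, not part of
the InvokeAI codebase.

    # Hypothetical illustration of the pattern restored by this revert;
    # not InvokeAI code. Requires Python 3.9+ for Queue[str] / dict[str, T].
    from queue import Queue
    from typing import Generic, Optional, TypeVar

    T = TypeVar("T")

    class BoundedForwardCache(Generic[T]):
        def __init__(self, max_cache_size: int = 20):
            self._cache: dict[str, T] = {}         # name -> cached object
            self._cache_ids: Queue[str] = Queue()  # names in insertion order
            self._max_cache_size = max_cache_size

        def get(self, name: str) -> Optional[T]:
            # Return the cached object, or None on a miss.
            return self._cache.get(name)

        def set(self, name: str, data: T) -> None:
            if name not in self._cache:
                self._cache[name] = data
                self._cache_ids.put(name)
                if self._cache_ids.qsize() > self._max_cache_size:
                    # FIFO eviction: drop the oldest cached entry.
                    self._cache.pop(self._cache_ids.get())

As in the reverted code, the dict and queue are shared by every thread, so
callers that write from multiple threads need their own serialization; the
per-thread (TID-keyed) caches removed here were one way around that.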