Mirror of https://github.com/invoke-ai/InvokeAI
feat(mm): support cache callbacks
@@ -4,7 +4,7 @@ import threading
 import time
 from functools import wraps
 from logging import Logger
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Protocol
 
 import psutil
 import torch
@@ -54,6 +54,22 @@ def synchronized(method: Callable[..., Any]) -> Callable[..., Any]:
     return wrapper
 
 
+class CacheMissCallback(Protocol):
+    def __call__(
+        self,
+        model_key: str,
+        cache_overview: dict[str, int],
+    ) -> None: ...
+
+
+class CacheHitCallback(Protocol):
+    def __call__(
+        self,
+        model_key: str,
+        cache_overview: dict[str, int],
+    ) -> None: ...
+
+
 class ModelCache:
     """A cache for managing models in memory.
 
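These callback Protocols are structural: any callable that accepts model_key and cache_overview keyword arguments satisfies them, whether a plain function or an object with __call__. A minimal sketch of a conforming callback (the MissCounter class below is illustrative only and not part of this commit):

class MissCounter:
    """Example CacheMissCallback: count misses and remember the cache footprint at miss time."""

    def __init__(self) -> None:
        self.misses = 0
        self.last_cache_bytes = 0

    def __call__(self, model_key: str, cache_overview: dict[str, int]) -> None:
        # cache_overview maps each cached model key to its size in bytes.
        self.misses += 1
        self.last_cache_bytes = sum(cache_overview.values())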
@@ -144,6 +160,21 @@ class ModelCache:
         # - Requests to empty the cache from a separate thread
         self._lock = threading.RLock()
 
+        self._on_cache_hit_callbacks: set[CacheHitCallback] = set()
+        self._on_cache_miss_callbacks: set[CacheMissCallback] = set()
+
+    def register_on_cache_hit(self, cb: CacheHitCallback) -> None:
+        self._on_cache_hit_callbacks.add(cb)
+
+    def register_on_cache_miss(self, cb: CacheMissCallback) -> None:
+        self._on_cache_miss_callbacks.add(cb)
+
+    def unregister_on_cache_hit(self, cb: CacheHitCallback) -> None:
+        self._on_cache_hit_callbacks.discard(cb)
+
+    def unregister_on_cache_miss(self, cb: CacheMissCallback) -> None:
+        self._on_cache_miss_callbacks.discard(cb)
+
     @property
     @synchronized
     def stats(self) -> Optional[CacheStats]:
@@ -195,6 +226,15 @@ class ModelCache:
             f"Added model {key} (Type: {model.__class__.__name__}, Wrap mode: {wrapped_model.__class__.__name__}, Model size: {size / MB:.2f}MB)"
         )
 
+    @synchronized
+    def _get_cache_overview(self) -> dict[str, int]:
+        overview: dict[str, int] = {}
+        for model_key, cache_entry in self._cached_models.items():
+            overview[model_key] = cache_entry.cached_model.total_bytes()
+            # Useful? cache_entry.cached_model.is_in_vram()
+
+        return overview
+
     @synchronized
     def get(self, key: str, stats_name: Optional[str] = None) -> CacheRecord:
         """Retrieve a model from the cache.
@@ -208,6 +248,8 @@ class ModelCache:
             if self.stats:
                 self.stats.hits += 1
         else:
+            for cb in self._on_cache_miss_callbacks:
+                cb(model_key=key, cache_overview=self._get_cache_overview())
             if self.stats:
                 self.stats.misses += 1
             self._logger.debug(f"Cache miss: {key}")
@@ -229,6 +271,8 @@ class ModelCache:
         self._cache_stack.append(key)
 
         self._logger.debug(f"Cache hit: {key} (Type: {cache_entry.cached_model.model.__class__.__name__})")
+        for cb in self._on_cache_hit_callbacks:
+            cb(model_key=key, cache_overview=self._get_cache_overview())
         return cache_entry
 
     @synchronized
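Putting the new API together, a consumer might register a pair of logging callbacks on an existing ModelCache and detach them later. A hedged usage sketch, assuming only the methods added in this commit; attach_cache_logging is an illustrative helper, not part of the InvokeAI codebase:

from logging import Logger
from typing import Callable


def attach_cache_logging(cache: ModelCache, logger: Logger) -> Callable[[], None]:
    """Register hit/miss logging callbacks on `cache` and return a function that detaches them."""

    def on_hit(model_key: str, cache_overview: dict[str, int]) -> None:
        logger.debug("cache hit: %s (%d models resident)", model_key, len(cache_overview))

    def on_miss(model_key: str, cache_overview: dict[str, int]) -> None:
        logger.debug("cache miss: %s (%d models resident)", model_key, len(cache_overview))

    cache.register_on_cache_hit(on_hit)
    cache.register_on_cache_miss(on_miss)

    def detach() -> None:
        # Callbacks are held in sets, so unregistering is idempotent.
        cache.unregister_on_cache_hit(on_hit)
        cache.unregister_on_cache_miss(on_miss)

    return detach

Note that get() fires the callbacks while holding the cache's re-entrant lock (it is decorated with @synchronized), so long-running callbacks will block other threads' access to the cache.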