feat(mm): support cache callbacks

This commit is contained in:
psychedelicious
2025-05-15 11:23:58 +10:00
parent 8b5f4d190c
commit a33da450fd

View File

@ -4,7 +4,7 @@ import threading
import time
from functools import wraps
from logging import Logger
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional, Protocol
import psutil
import torch
@ -54,6 +54,22 @@ def synchronized(method: Callable[..., Any]) -> Callable[..., Any]:
return wrapper
class CacheMissCallback(Protocol):
def __call__(
self,
model_key: str,
cache_overview: dict[str, int],
) -> None: ...
class CacheHitCallback(Protocol):
def __call__(
self,
model_key: str,
cache_overview: dict[str, int],
) -> None: ...
class ModelCache:
"""A cache for managing models in memory.
@ -144,6 +160,21 @@ class ModelCache:
# - Requests to empty the cache from a separate thread
self._lock = threading.RLock()
self._on_cache_hit_callbacks: set[CacheHitCallback] = set()
self._on_cache_miss_callbacks: set[CacheMissCallback] = set()
def register_on_cache_hit(self, cb: CacheHitCallback) -> None:
self._on_cache_hit_callbacks.add(cb)
def register_on_cache_miss(self, cb: CacheMissCallback) -> None:
self._on_cache_miss_callbacks.add(cb)
def unregister_on_cache_hit(self, cb: CacheHitCallback) -> None:
self._on_cache_hit_callbacks.discard(cb)
def unregister_on_cache_miss(self, cb: CacheMissCallback) -> None:
self._on_cache_miss_callbacks.discard(cb)
@property
@synchronized
def stats(self) -> Optional[CacheStats]:
@ -195,6 +226,15 @@ class ModelCache:
f"Added model {key} (Type: {model.__class__.__name__}, Wrap mode: {wrapped_model.__class__.__name__}, Model size: {size / MB:.2f}MB)"
)
@synchronized
def _get_cache_overview(self) -> dict[str, int]:
overview: dict[str, int] = {}
for model_key, cache_entry in self._cached_models.items():
overview[model_key] = cache_entry.cached_model.total_bytes()
# Useful? cache_entry.cached_model.is_in_vram()
return overview
@synchronized
def get(self, key: str, stats_name: Optional[str] = None) -> CacheRecord:
"""Retrieve a model from the cache.
@ -208,6 +248,8 @@ class ModelCache:
if self.stats:
self.stats.hits += 1
else:
for cb in self._on_cache_miss_callbacks:
cb(model_key=key, cache_overview=self._get_cache_overview())
if self.stats:
self.stats.misses += 1
self._logger.debug(f"Cache miss: {key}")
@ -229,6 +271,8 @@ class ModelCache:
self._cache_stack.append(key)
self._logger.debug(f"Cache hit: {key} (Type: {cache_entry.cached_model.model.__class__.__name__})")
for cb in self._on_cache_hit_callbacks:
cb(model_key=key, cache_overview=self._get_cache_overview())
return cache_entry
@synchronized