fix(mm): typing issues in model cache

This commit is contained in:
psychedelicious 2024-04-05 18:41:10 +11:00
parent a09d705e4c
commit 4068e817d6
2 changed files with 4 additions and 4 deletions

View File

@@ -117,7 +117,7 @@ class ModelCacheBase(ABC, Generic[T]):
@property @property
@abstractmethod @abstractmethod
def stats(self) -> CacheStats: def stats(self) -> Optional[CacheStats]:
"""Return collected CacheStats object.""" """Return collected CacheStats object."""
pass pass

View File

@@ -326,11 +326,11 @@ class ModelCache(ModelCacheBase[AnyModel]):
f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})" f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})"
) )
def make_room(self, model_size: int) -> None: def make_room(self, size: int) -> None:
"""Make enough room in the cache to accommodate a new model of indicated size.""" """Make enough room in the cache to accommodate a new model of indicated size."""
# calculate how much memory this model will require # calculate how much memory this model will require
# multiplier = 2 if self.precision==torch.float32 else 1 # multiplier = 2 if self.precision==torch.float32 else 1
bytes_needed = model_size bytes_needed = size
maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes
current_size = self.cache_size() current_size = self.cache_size()
@@ -385,7 +385,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
# 1 from onnx runtime object # 1 from onnx runtime object
if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2): if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2):
self.logger.debug( self.logger.debug(
f"Removing {model_key} from RAM cache to free at least {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)" f"Removing {model_key} from RAM cache to free at least {(size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
) )
current_size -= cache_entry.size current_size -= cache_entry.size
models_cleared += 1 models_cleared += 1