Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
Report RAM usage and RAM cache statistics after each generation (#4287)
## What type of PR is this? (check all applicable)

- [X] Feature

## Have you discussed this change with the InvokeAI team?

- [X] Yes

## Have you updated all relevant documentation?

- [X] Yes

## Description

This PR enhances the logging of performance statistics to include RAM and model cache information. After each generation, the following will be logged. The new information follows TOTAL GRAPH EXECUTION TIME.

```
[2023-08-15 21:55:39,010]::[InvokeAI]::INFO --> Graph stats: 2408dbec-50d0-44a3-bbc4-427037e3f7d4
[2023-08-15 21:55:39,010]::[InvokeAI]::INFO --> Node                  Calls  Seconds  VRAM Used
[2023-08-15 21:55:39,010]::[InvokeAI]::INFO --> main_model_loader         1   0.004s     0.000G
[2023-08-15 21:55:39,010]::[InvokeAI]::INFO --> clip_skip                 1   0.002s     0.000G
[2023-08-15 21:55:39,010]::[InvokeAI]::INFO --> compel                    2   2.706s     0.246G
[2023-08-15 21:55:39,010]::[InvokeAI]::INFO --> rand_int                  1   0.002s     0.244G
[2023-08-15 21:55:39,011]::[InvokeAI]::INFO --> range_of_size             1   0.002s     0.244G
[2023-08-15 21:55:39,011]::[InvokeAI]::INFO --> iterate                   1   0.002s     0.244G
[2023-08-15 21:55:39,011]::[InvokeAI]::INFO --> metadata_accumulator      1   0.002s     0.244G
[2023-08-15 21:55:39,011]::[InvokeAI]::INFO --> noise                     1   0.003s     0.244G
[2023-08-15 21:55:39,011]::[InvokeAI]::INFO --> denoise_latents           1   2.429s     2.022G
[2023-08-15 21:55:39,011]::[InvokeAI]::INFO --> l2i                       1   1.020s     1.858G
[2023-08-15 21:55:39,011]::[InvokeAI]::INFO --> TOTAL GRAPH EXECUTION TIME:   6.171s
[2023-08-15 21:55:39,011]::[InvokeAI]::INFO --> RAM used by InvokeAI process: 4.50G (delta=0.10G)
[2023-08-15 21:55:39,011]::[InvokeAI]::INFO --> RAM used to load models: 1.99G
[2023-08-15 21:55:39,011]::[InvokeAI]::INFO --> VRAM in use: 0.303G
[2023-08-15 21:55:39,011]::[InvokeAI]::INFO --> RAM cache statistics:
[2023-08-15 21:55:39,011]::[InvokeAI]::INFO -->    Model cache hits: 2
[2023-08-15 21:55:39,011]::[InvokeAI]::INFO -->    Model cache misses: 5
[2023-08-15 21:55:39,011]::[InvokeAI]::INFO -->    Models cached: 5
[2023-08-15 21:55:39,011]::[InvokeAI]::INFO -->    Models cleared from cache: 0
[2023-08-15 21:55:39,011]::[InvokeAI]::INFO -->    Cache high water mark: 1.99/7.50G
```

There may be a memory leak in InvokeAI. I'm seeing the process memory usage increasing by about 100 MB with each generation as shown in the example above.
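The new RAM figures come from `psutil`, which the statistics service now imports: the process RSS is sampled when a node starts and again when it finishes, and the delta is reported in gibibytes. Below is a small, self-contained sketch of that measurement pattern for anyone who wants to reproduce the per-generation RAM delta outside InvokeAI. It is only an illustration of the approach, not the actual `StatsContext` class; the `RamAndTimeProbe` name, the label, and the fake workload are placeholders.

```python
import time

import psutil

GIG = 1073741824  # bytes per GiB; the change adds the same constant as GIG


class RamAndTimeProbe:
    """Illustrative context manager: sample wall time and process RSS on entry, report the deltas on exit."""

    def __init__(self, label: str):
        self.label = label
        self.start_time = 0.0
        self.ram_used = 0

    def __enter__(self):
        self.start_time = time.time()
        self.ram_used = psutil.Process().memory_info().rss
        return self

    def __exit__(self, *exc):
        elapsed = time.time() - self.start_time
        ram_now = psutil.Process().memory_info().rss
        print(f"{self.label}: {elapsed:7.3f}s  RAM {ram_now / GIG:4.2f}G (delta={(ram_now - self.ram_used) / GIG:+5.3f}G)")


if __name__ == "__main__":
    with RamAndTimeProbe("generation"):
        buffers = [bytearray(2**20) for _ in range(100)]  # stand-in for a generation's allocations (~100 MiB)
```

A steadily positive delta across many consecutive runs is the kind of growth the last paragraph above flags as a possible leak.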
Commit: ae986bf873
```diff
@@ -29,6 +29,7 @@ The abstract base class for this class is InvocationStatsServiceBase. An impleme
 writes to the system log is stored in InvocationServices.performance_statistics.
 """
 
+import psutil
 import time
 from abc import ABC, abstractmethod
 from contextlib import AbstractContextManager
@@ -42,6 +43,11 @@ import invokeai.backend.util.logging as logger
 from ..invocations.baseinvocation import BaseInvocation
 from .graph import GraphExecutionState
 from .item_storage import ItemStorageABC
+from .model_manager_service import ModelManagerService
+from invokeai.backend.model_management.model_cache import CacheStats
+
+# size of GIG in bytes
+GIG = 1073741824
 
 
 class InvocationStatsServiceBase(ABC):
@@ -89,6 +95,8 @@ class InvocationStatsServiceBase(ABC):
         invocation_type: str,
         time_used: float,
         vram_used: float,
+        ram_used: float,
+        ram_changed: float,
     ):
         """
         Add timing information on execution of a node. Usually
@@ -97,6 +105,8 @@ class InvocationStatsServiceBase(ABC):
         :param invocation_type: String literal type of the node
         :param time_used: Time used by node's exection (sec)
         :param vram_used: Maximum VRAM used during exection (GB)
+        :param ram_used: Current RAM available (GB)
+        :param ram_changed: Change in RAM usage over course of the run (GB)
         """
         pass
 
@@ -115,6 +125,9 @@ class NodeStats:
     calls: int = 0
     time_used: float = 0.0  # seconds
     max_vram: float = 0.0  # GB
+    cache_hits: int = 0
+    cache_misses: int = 0
+    cache_high_watermark: int = 0
 
 
 @dataclass
@@ -133,31 +146,62 @@ class InvocationStatsService(InvocationStatsServiceBase):
         self.graph_execution_manager = graph_execution_manager
         # {graph_id => NodeLog}
         self._stats: Dict[str, NodeLog] = {}
+        self._cache_stats: Dict[str, CacheStats] = {}
+        self.ram_used: float = 0.0
+        self.ram_changed: float = 0.0
 
     class StatsContext:
-        def __init__(self, invocation: BaseInvocation, graph_id: str, collector: "InvocationStatsServiceBase"):
+        """Context manager for collecting statistics."""
+
+        invocation: BaseInvocation = None
+        collector: "InvocationStatsServiceBase" = None
+        graph_id: str = None
+        start_time: int = 0
+        ram_used: int = 0
+        model_manager: ModelManagerService = None
+
+        def __init__(
+            self,
+            invocation: BaseInvocation,
+            graph_id: str,
+            model_manager: ModelManagerService,
+            collector: "InvocationStatsServiceBase",
+        ):
+            """Initialize statistics for this run."""
             self.invocation = invocation
             self.collector = collector
             self.graph_id = graph_id
             self.start_time = 0
+            self.ram_used = 0
+            self.model_manager = model_manager
 
         def __enter__(self):
            self.start_time = time.time()
            if torch.cuda.is_available():
                torch.cuda.reset_peak_memory_stats()
+            self.ram_used = psutil.Process().memory_info().rss
+            if self.model_manager:
+                self.model_manager.collect_cache_stats(self.collector._cache_stats[self.graph_id])
 
         def __exit__(self, *args):
+            """Called on exit from the context."""
+            ram_used = psutil.Process().memory_info().rss
+            self.collector.update_mem_stats(
+                ram_used=ram_used / GIG,
+                ram_changed=(ram_used - self.ram_used) / GIG,
+            )
             self.collector.update_invocation_stats(
-                self.graph_id,
-                self.invocation.type,
-                time.time() - self.start_time,
-                torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0,
+                graph_id=self.graph_id,
+                invocation_type=self.invocation.type,
+                time_used=time.time() - self.start_time,
+                vram_used=torch.cuda.max_memory_allocated() / GIG if torch.cuda.is_available() else 0.0,
             )
 
     def collect_stats(
         self,
         invocation: BaseInvocation,
         graph_execution_state_id: str,
+        model_manager: ModelManagerService,
     ) -> StatsContext:
         """
         Return a context object that will capture the statistics.
@@ -166,7 +210,8 @@ class InvocationStatsService(InvocationStatsServiceBase):
         """
         if not self._stats.get(graph_execution_state_id):  # first time we're seeing this
             self._stats[graph_execution_state_id] = NodeLog()
-        return self.StatsContext(invocation, graph_execution_state_id, self)
+            self._cache_stats[graph_execution_state_id] = CacheStats()
+        return self.StatsContext(invocation, graph_execution_state_id, model_manager, self)
 
     def reset_all_stats(self):
         """Zero all statistics"""
@@ -179,13 +224,36 @@ class InvocationStatsService(InvocationStatsServiceBase):
         except KeyError:
             logger.warning(f"Attempted to clear statistics for unknown graph {graph_execution_id}")
 
-    def update_invocation_stats(self, graph_id: str, invocation_type: str, time_used: float, vram_used: float):
+    def update_mem_stats(
+        self,
+        ram_used: float,
+        ram_changed: float,
+    ):
+        """
+        Update the collector with RAM memory usage info.
+
+        :param ram_used: How much RAM is currently in use.
+        :param ram_changed: How much RAM changed since last generation.
+        """
+        self.ram_used = ram_used
+        self.ram_changed = ram_changed
+
+    def update_invocation_stats(
+        self,
+        graph_id: str,
+        invocation_type: str,
+        time_used: float,
+        vram_used: float,
+    ):
         """
         Add timing information on execution of a node. Usually
         used internally.
         :param graph_id: ID of the graph that is currently executing
         :param invocation_type: String literal type of the node
-        :param time_used: Floating point seconds used by node's exection
+        :param time_used: Time used by node's exection (sec)
+        :param vram_used: Maximum VRAM used during exection (GB)
+        :param ram_used: Current RAM available (GB)
+        :param ram_changed: Change in RAM usage over course of the run (GB)
         """
         if not self._stats[graph_id].nodes.get(invocation_type):
             self._stats[graph_id].nodes[invocation_type] = NodeStats()
@@ -197,7 +265,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
     def log_stats(self):
         """
         Send the statistics to the system logger at the info level.
-        Stats will only be printed if when the execution of the graph
+        Stats will only be printed when the execution of the graph
         is complete.
         """
         completed = set()
@@ -208,16 +276,30 @@ class InvocationStatsService(InvocationStatsServiceBase):
 
             total_time = 0
             logger.info(f"Graph stats: {graph_id}")
-            logger.info("Node Calls Seconds VRAM Used")
+            logger.info(f"{'Node':>30} {'Calls':>7}{'Seconds':>9} {'VRAM Used':>10}")
             for node_type, stats in self._stats[graph_id].nodes.items():
-                logger.info(f"{node_type:<20} {stats.calls:>5} {stats.time_used:7.3f}s {stats.max_vram:4.2f}G")
+                logger.info(f"{node_type:>30} {stats.calls:>4} {stats.time_used:7.3f}s {stats.max_vram:4.3f}G")
                 total_time += stats.time_used
 
+            cache_stats = self._cache_stats[graph_id]
+            hwm = cache_stats.high_watermark / GIG
+            tot = cache_stats.cache_size / GIG
+            loaded = sum([v for v in cache_stats.loaded_model_sizes.values()]) / GIG
+
             logger.info(f"TOTAL GRAPH EXECUTION TIME: {total_time:7.3f}s")
+            logger.info("RAM used by InvokeAI process: " + "%4.2fG" % self.ram_used + f" ({self.ram_changed:+5.3f}G)")
+            logger.info(f"RAM used to load models: {loaded:4.2f}G")
             if torch.cuda.is_available():
-                logger.info("Current VRAM utilization " + "%4.2fG" % (torch.cuda.memory_allocated() / 1e9))
+                logger.info("VRAM in use: " + "%4.3fG" % (torch.cuda.memory_allocated() / GIG))
+            logger.info("RAM cache statistics:")
+            logger.info(f"   Model cache hits: {cache_stats.hits}")
+            logger.info(f"   Model cache misses: {cache_stats.misses}")
+            logger.info(f"   Models cached: {cache_stats.in_cache}")
+            logger.info(f"   Models cleared from cache: {cache_stats.cleared}")
+            logger.info(f"   Cache high water mark: {hwm:4.2f}/{tot:4.2f}G")
 
             completed.add(graph_id)
 
         for graph_id in completed:
             del self._stats[graph_id]
+            del self._cache_stats[graph_id]
```
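A small but visible change in the hunks above: VRAM figures were previously divided by `1e9` (decimal gigabytes) but are now divided by `GIG = 1073741824` (2^30, i.e. gibibytes), so the same allocation reads about 7% lower after this change. The numbers below are only an illustrative check of that conversion, not values taken from the PR:

```python
GIG = 1073741824  # 2**30, the constant used throughout the diff

vram_bytes = 2_171_234_816  # example allocation of roughly 2.17e9 bytes
print(vram_bytes / 1e9)     # 2.171234816 -> what the old "/ 1e9" reporting would show
print(vram_bytes / GIG)     # ~2.022      -> what the new "/ GIG" reporting shows
print(1e9 / GIG)            # ~0.9313     -> new numbers are ~6.9% smaller for the same byte count
```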
```diff
@@ -22,6 +22,7 @@ from invokeai.backend.model_management import (
     ModelNotFoundException,
 )
 from invokeai.backend.model_management.model_search import FindModels
+from invokeai.backend.model_management.model_cache import CacheStats
 
 import torch
 from invokeai.app.models.exceptions import CanceledException
@@ -276,6 +277,13 @@ class ModelManagerServiceBase(ABC):
         """
         pass
 
+    @abstractmethod
+    def collect_cache_stats(self, cache_stats: CacheStats):
+        """
+        Reset model cache statistics for graph with graph_id.
+        """
+        pass
+
     @abstractmethod
     def commit(self, conf_file: Optional[Path] = None) -> None:
         """
@@ -500,6 +508,12 @@ class ModelManagerService(ModelManagerServiceBase):
         self.logger.debug(f"convert model {model_name}")
         return self.mgr.convert_model(model_name, base_model, model_type, convert_dest_directory)
 
+    def collect_cache_stats(self, cache_stats: CacheStats):
+        """
+        Reset model cache statistics for graph with graph_id.
+        """
+        self.mgr.cache.stats = cache_stats
+
     def commit(self, conf_file: Optional[Path] = None):
         """
         Write current configuration out to the indicated file.
```
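Note the design in `collect_cache_stats()` above: the service does not copy statistics out of the cache; it hands the per-graph `CacheStats` object to the cache (`self.mgr.cache.stats = cache_stats`), and the cache then increments that object in place as models are requested (see the `ModelCache` hunks further below). Here is a compact, self-contained stand-in for that pattern; `TinyModelCache` and its fake model sizes are invented for illustration and are not part of InvokeAI:

```python
from dataclasses import dataclass, field
from typing import Dict, Optional


@dataclass
class CacheStats:
    """Toy subset of the CacheStats fields added by this PR."""
    hits: int = 0
    misses: int = 0
    in_cache: int = 0
    loaded_model_sizes: Dict[str, int] = field(default_factory=dict)


class TinyModelCache:
    """Illustrative cache: it tallies into whatever stats object the service layer attaches."""

    def __init__(self):
        self._models: Dict[str, int] = {}        # key -> pretend model size in bytes
        self.stats: Optional[CacheStats] = None  # attached per graph, like ModelCache.stats

    def get_model(self, key: str, size: int) -> int:
        if key not in self._models:
            self._models[key] = size             # pretend to load the model
            if self.stats:
                self.stats.misses += 1
        elif self.stats:
            self.stats.hits += 1
        if self.stats:
            self.stats.in_cache = len(self._models)
            self.stats.loaded_model_sizes[key] = max(self.stats.loaded_model_sizes.get(key, 0), size)
        return self._models[key]


# Per-graph wiring, mirroring collect_cache_stats(): attach a fresh CacheStats,
# run the graph's model lookups, then read the tallies when logging.
cache = TinyModelCache()
graph_stats = CacheStats()
cache.stats = graph_stats                        # what collect_cache_stats() does
cache.get_model("main/unet", 1_700_000_000)      # first request: a miss
cache.get_model("main/unet", 1_700_000_000)      # second request: a hit
print(graph_stats)                               # CacheStats(hits=1, misses=1, in_cache=1, ...)
```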
```diff
@@ -86,7 +86,9 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
 
                 # Invoke
                 try:
-                    with statistics.collect_stats(invocation, graph_execution_state.id):
+                    graph_id = graph_execution_state.id
+                    model_manager = self.__invoker.services.model_manager
+                    with statistics.collect_stats(invocation, graph_id, model_manager):
                         # use the internal invoke_internal(), which wraps the node's invoke() method in
                         # this accomodates nodes which require a value, but get it only from a
                         # connection
```
```diff
@@ -21,12 +21,12 @@ import os
 import sys
 import hashlib
 from contextlib import suppress
+from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Dict, Union, types, Optional, Type, Any
 
 import torch
 
-import logging
 import invokeai.backend.util.logging as logger
 from .models import BaseModelType, ModelType, SubModelType, ModelBase
 
@@ -41,6 +41,18 @@ DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75
 GIG = 1073741824
 
 
+@dataclass
+class CacheStats(object):
+    hits: int = 0  # cache hits
+    misses: int = 0  # cache misses
+    high_watermark: int = 0  # amount of cache used
+    in_cache: int = 0  # number of models in cache
+    cleared: int = 0  # number of models cleared to make space
+    cache_size: int = 0  # total size of cache
+    # {submodel_key => size}
+    loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
+
+
 class ModelLocker(object):
     "Forward declaration"
     pass
@@ -115,6 +127,9 @@ class ModelCache(object):
         self.sha_chunksize = sha_chunksize
         self.logger = logger
 
+        # used for stats collection
+        self.stats = None
+
         self._cached_models = dict()
         self._cache_stack = list()
 
@@ -181,13 +196,14 @@ class ModelCache(object):
             model_type=model_type,
             submodel_type=submodel,
         )
-
         # TODO: lock for no copies on simultaneous calls?
         cache_entry = self._cached_models.get(key, None)
         if cache_entry is None:
             self.logger.info(
                 f"Loading model {model_path}, type {base_model.value}:{model_type.value}{':'+submodel.value if submodel else ''}"
             )
+            if self.stats:
+                self.stats.misses += 1
 
             # this will remove older cached models until
             # there is sufficient room to load the requested model
@@ -201,6 +217,17 @@ class ModelCache(object):
 
             cache_entry = _CacheRecord(self, model, mem_used)
             self._cached_models[key] = cache_entry
+        else:
+            if self.stats:
+                self.stats.hits += 1
+
+        if self.stats:
+            self.stats.cache_size = self.max_cache_size * GIG
+            self.stats.high_watermark = max(self.stats.high_watermark, self._cache_size())
+            self.stats.in_cache = len(self._cached_models)
+            self.stats.loaded_model_sizes[key] = max(
+                self.stats.loaded_model_sizes.get(key, 0), model_info.get_size(submodel)
+            )
 
         with suppress(Exception):
             self._cache_stack.remove(key)
@@ -280,14 +307,14 @@ class ModelCache(object):
         """
         Given the HF repo id or path to a model on disk, returns a unique
         hash. Works for legacy checkpoint files, HF models on disk, and HF repo IDs
+
         :param model_path: Path to model file/directory on disk.
         """
         return self._local_model_hash(model_path)
 
     def cache_size(self) -> float:
-        "Return the current size of the cache, in GB"
-        current_cache_size = sum([m.size for m in self._cached_models.values()])
-        return current_cache_size / GIG
+        """Return the current size of the cache, in GB."""
+        return self._cache_size() / GIG
 
     def _has_cuda(self) -> bool:
         return self.execution_device.type == "cuda"
@@ -310,12 +337,15 @@ class ModelCache(object):
                 f"Current VRAM/RAM usage: {vram}/{ram}; cached_models/loaded_models/locked_models/ = {cached_models}/{loaded_models}/{locked_models}"
             )
 
+    def _cache_size(self) -> int:
+        return sum([m.size for m in self._cached_models.values()])
+
     def _make_cache_room(self, model_size):
         # calculate how much memory this model will require
         # multiplier = 2 if self.precision==torch.float32 else 1
         bytes_needed = model_size
         maximum_size = self.max_cache_size * GIG  # stored in GB, convert to bytes
-        current_size = sum([m.size for m in self._cached_models.values()])
+        current_size = self._cache_size()
 
         if current_size + bytes_needed > maximum_size:
             self.logger.debug(
@@ -364,6 +394,8 @@ class ModelCache(object):
                     f"Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
                 )
                 current_size -= cache_entry.size
+                if self.stats:
+                    self.stats.cleared += 1
                 del self._cache_stack[pos]
                 del self._cached_models[model_key]
                 del cache_entry
```