added StALKeR779's great model size calculating routine

This commit is contained in:
Lincoln Stein 2023-05-08 21:47:03 -04:00
parent c15b49c805
commit a108155544

View File

@ -21,15 +21,14 @@ import gc
import hashlib import hashlib
import warnings import warnings
from collections import Counter from collections import Counter
from enum import Enum,auto from enum import Enum
from pathlib import Path from pathlib import Path
from psutil import Process
from typing import Dict, Sequence, Union, Tuple, types from typing import Dict, Sequence, Union, Tuple, types
import torch import torch
import safetensors.torch import safetensors.torch
from diffusers import StableDiffusionPipeline, AutoencoderKL, SchedulerMixin, UNet2DConditionModel from diffusers import DiffusionPipeline, StableDiffusionPipeline, AutoencoderKL, SchedulerMixin, UNet2DConditionModel
from diffusers import logging as diffusers_logging from diffusers import logging as diffusers_logging
from diffusers.pipelines.stable_diffusion.safety_checker import \ from diffusers.pipelines.stable_diffusion.safety_checker import \
StableDiffusionSafetyChecker StableDiffusionSafetyChecker
@ -84,7 +83,7 @@ class ModelStatus(Enum):
# After loading, we will know it exactly. # After loading, we will know it exactly.
# Sizes are in Gigs, estimated for float16; double for float32 # Sizes are in Gigs, estimated for float16; double for float32
SIZE_GUESSTIMATE = { SIZE_GUESSTIMATE = {
SDModelType.diffusers: 2.5, SDModelType.diffusers: 2.2,
SDModelType.vae: 0.35, SDModelType.vae: 0.35,
SDModelType.text_encoder: 0.5, SDModelType.text_encoder: 0.5,
SDModelType.tokenizer: 0.001, SDModelType.tokenizer: 0.001,
@ -255,7 +254,6 @@ class ModelCache(object):
# clean memory to make MemoryUsage() more accurate # clean memory to make MemoryUsage() more accurate
gc.collect() gc.collect()
with MemoryUsage() as usage:
model = self._load_model_from_storage( model = self._load_model_from_storage(
repo_id_or_path=repo_id_or_path, repo_id_or_path=repo_id_or_path,
model_class=model_type.value, model_class=model_type.value,
@ -263,9 +261,11 @@ class ModelCache(object):
revision=revision, revision=revision,
legacy_info=legacy_info, legacy_info=legacy_info,
) )
logger.debug(f'Actual memory used to load model: {(usage.mem_used/GIG):.2f} GB')
self.model_sizes[key] = usage.mem_used # remember size of this model for cache cleansing if mem_used := self.calc_model_size(model):
self.current_cache_size += usage.mem_used # increment size of the cache logger.debug(f'CPU RAM used for load: {(mem_used/GIG):.2f} GB')
self.model_sizes[key] = mem_used # remember size of this model for cache cleansing
self.current_cache_size += mem_used # increment size of the cache
# this is a bit of legacy work needed to support the old-style "load this diffuser with custom VAE" # this is a bit of legacy work needed to support the old-style "load this diffuser with custom VAE"
if model_type==SDModelType.diffusers and attach_model_part[0]: if model_type==SDModelType.diffusers and attach_model_part[0]:
@ -308,7 +308,10 @@ class ModelCache(object):
cache._offload_unlocked_models() cache._offload_unlocked_models()
if model.device != cache.execution_device: if model.device != cache.execution_device:
cache.logger.debug(f'Moving {key} into {cache.execution_device}') cache.logger.debug(f'Moving {key} into {cache.execution_device}')
with VRAMUsage() as mem:
model.to(cache.execution_device) # move into GPU model.to(cache.execution_device) # move into GPU
cache.logger.debug(f'GPU VRAM used for load: {(mem.vram_used/GIG):.2f} GB')
cache.model_sizes[key] = mem.vram_used # more accurate size
cache.logger.debug(f'Locking {key} in {cache.execution_device}') cache.logger.debug(f'Locking {key} in {cache.execution_device}')
cache._print_cuda_stats() cache._print_cuda_stats()
else: else:
@ -479,8 +482,9 @@ class ModelCache(object):
''' '''
# silence transformer and diffuser warnings # silence transformer and diffuser warnings
with SilenceWarnings(): with SilenceWarnings():
# !!! NOTE: conversion should not happen here, but in ModelManager
if self.is_legacy_ckpt(repo_id_or_path): if self.is_legacy_ckpt(repo_id_or_path):
model = model_class(self._load_ckpt_from_storage(repo_id_or_path, legacy_info)) model = self._load_ckpt_from_storage(repo_id_or_path, legacy_info)
else: else:
model = self._load_diffusers_from_storage( model = self._load_diffusers_from_storage(
repo_id_or_path, repo_id_or_path,
@ -608,6 +612,30 @@ class ModelCache(object):
raise KeyError(f"Revision '{revision}' not found in {repo_id}") raise KeyError(f"Revision '{revision}' not found in {repo_id}")
return desired_revisions[0].target_commit return desired_revisions[0].target_commit
@staticmethod
def calc_model_size(model)->int:
if isinstance(model,DiffusionPipeline):
return ModelCache._calc_pipeline(model)
elif isinstance(model,torch.nn.Module):
return ModelCache._calc_model(model)
else:
return None
@staticmethod
def _calc_pipeline(pipeline)->int:
res = 0
for submodel_key in pipeline.components.keys():
submodel = getattr(pipeline, submodel_key)
if submodel is not None and isinstance(submodel, torch.nn.Module):
res += ModelCache._calc_model(submodel)
return res
@staticmethod
def _calc_model(model)->int:
mem_params = sum([param.nelement()*param.element_size() for param in model.parameters()])
mem_bufs = sum([buf.nelement()*buf.element_size() for buf in model.buffers()])
mem = mem_params + mem_bufs # in bytes
return mem
class SilenceWarnings(object): class SilenceWarnings(object):
def __init__(self): def __init__(self):
@ -624,14 +652,14 @@ class SilenceWarnings(object):
diffusers_logging.set_verbosity(self.diffusers_verbosity) diffusers_logging.set_verbosity(self.diffusers_verbosity)
warnings.simplefilter('default') warnings.simplefilter('default')
class MemoryUsage(object): class VRAMUsage(object):
def __init__(self): def __init__(self):
self.vms = None self.vram = None
self.mem_used = 0 self.vram_used = 0
def __enter__(self): def __enter__(self):
self.vms = Process().memory_info().vms self.vram = torch.cuda.memory_allocated()
return self return self
def __exit__(self, *args): def __exit__(self, *args):
self.mem_used = Process().memory_info().vms - self.vms self.vram_used = torch.cuda.memory_allocated() - self.vram