2023-04-28 04:41:52 +00:00
|
|
|
"""
Manage a RAM cache of diffusion/transformer models for fast switching.
They are moved between GPU VRAM and CPU RAM as necessary. If the cache
grows larger than a preset maximum, then the least recently used
model will be cleared and (re)loaded from disk when next needed.

The cache returns context managers designed to load the model into the
GPU within the context and release it outside the context (with lazy
offloading, the model stays in VRAM until another model needs the room).
Use like this:

    cache = ModelCache(max_cache_size=6.0)
    with cache.get_model('runwayml/stable-diffusion-v1-5') as SD1:
        with cache.get_model('stabilityai/stable-diffusion-2') as SD2:
            do_something_in_GPU(SD1, SD2)

"""
|
|
|
|
|
|
|
|
import contextlib
import gc
import hashlib
import json
import os
import sys
import warnings
from contextlib import suppress
from enum import Enum
from pathlib import Path
from typing import Dict, Sequence, Union, Tuple, types, Optional, List, Type, Any

import torch
import safetensors.torch
from diffusers import DiffusionPipeline, SchedulerMixin, ConfigMixin
from diffusers import logging as diffusers_logging
from huggingface_hub import HfApi, scan_cache_dir
from picklescan.scanner import scan_file_path
from pydantic import BaseModel
from transformers import logging as transformers_logging

import invokeai.backend.util.logging as logger

from ..globals import global_cache_dir
|
2023-04-28 04:41:52 +00:00
|
|
|
|
2023-05-07 22:07:28 +00:00
|
|
|
|
2023-05-18 00:56:52 +00:00
|
|
|
def get_model_path(repo_id_or_path: str):
    """Resolve a local path or a HuggingFace repo id to a location on disk.

    If ``repo_id_or_path`` exists on the local filesystem it is returned
    unchanged. Otherwise the HuggingFace hub cache under
    ``global_cache_dir("hub")`` is scanned for a repo with that id, and the
    snapshot path of the revision referenced by ``main`` is returned.

    :param repo_id_or_path: local path, or HuggingFace repo id.
    :returns: the unchanged local path, or the cached snapshot path.
    :raises FileNotFoundError: if the path does not exist locally and no cached
        ``main`` revision of the repo is found.
    """
    if os.path.exists(repo_id_or_path):
        return repo_id_or_path

    cache = scan_cache_dir(global_cache_dir("hub"))
    for repo in cache.repos:
        if repo.repo_id != repo_id_or_path:
            continue
        for rev in repo.revisions:
            if "main" in rev.refs:
                return rev.snapshot_path

    # FileNotFoundError is an Exception subclass, so existing callers that
    # catch Exception keep working.
    raise FileNotFoundError(f"{repo_id_or_path} - not found")
|
|
|
|
|
|
|
|
def calc_model_size_by_fs(
    repo_id_or_path: str,
    subfolder: Optional[str] = None,
    variant: Optional[str] = None,
):
    """Estimate a model's size in bytes from the files on disk.

    Files are split by variant: names containing ``.fp16.``/``.fp16-`` belong
    to the ``"fp16"`` variant, ``.8bit.``/``.8bit-`` to ``"8bit"``, and all
    others to the default variant (``variant=None``).

    If a sharded-checkpoint index (``*.index.json``, or
    ``*.index.<variant>.json`` for a variant) is present and readable, its
    ``metadata.total_size`` is returned. Otherwise the sizes of the weight
    files of the first format found are summed. Formats are tried in this
    order: safetensors, bin, onnx/pb, msgpack, ckpt, h5.

    :param repo_id_or_path: local path or HuggingFace repo id.
    :param subfolder: optional sub-directory holding the component.
    :param variant: None, "fp16" or "8bit".
    :returns: size in bytes, or 0 if no weight files are found (for example
        scheduler, feature_extractor or tokenizer folders).
    :raises NotImplementedError: for an unknown variant.
    """
    model_path = get_model_path(repo_id_or_path)
    if subfolder is not None:
        model_path = os.path.join(model_path, subfolder)

    all_files = [
        f for f in os.listdir(model_path)
        if os.path.isfile(os.path.join(model_path, f))
    ]

    fp16_files = {f for f in all_files if ".fp16." in f or ".fp16-" in f}
    bit8_files = {f for f in all_files if ".8bit." in f or ".8bit-" in f}
    other_files = set(all_files) - fp16_files - bit8_files

    if variant is None:
        files = other_files
    elif variant == "fp16":
        files = fp16_files
    elif variant == "8bit":
        files = bit8_files
    else:
        raise NotImplementedError(f"Unknown variant: {variant}")

    # Prefer the total recorded in an index file, if one exists.
    index_postfix = ".index.json" if variant is None else f".index.{variant}.json"
    for file in files:
        if not file.endswith(index_postfix):
            continue
        try:
            with open(os.path.join(model_path, file), "r", encoding="utf-8") as f:
                index_data = json.load(f)
            return int(index_data["metadata"]["total_size"])
        except (OSError, ValueError, KeyError, TypeError):
            # Unreadable or malformed index: fall back to summing file sizes.
            pass

    # Calculate the files' size if there is no usable index file.
    formats = [
        (".safetensors",),  # safetensors
        (".bin",),  # torch
        (".onnx", ".pb"),  # onnx
        (".msgpack",),  # flax
        (".ckpt",),  # tf
        (".h5",),  # tf2
    ]

    for file_format in formats:
        model_files = [f for f in files if f.endswith(file_format)]
        if not model_files:
            continue
        return sum(os.path.getsize(os.path.join(model_path, f)) for f in model_files)

    # scheduler/feature_extractor/tokenizer - models without loading to gpu
    return 0
|
|
|
|
|
|
|
|
|
|
|
|
def calc_model_size_by_data(model) -> int:
    """Return the in-memory size of a loaded model object, in bytes.

    A ``DiffusionPipeline`` is measured as the sum of its ``torch.nn.Module``
    components, and a ``torch.nn.Module`` by its parameters plus buffers.
    Any other object is reported as 0 bytes.
    """
    measure = None
    if isinstance(model, DiffusionPipeline):
        measure = _calc_pipeline_by_data
    elif isinstance(model, torch.nn.Module):
        measure = _calc_model_by_data
    return 0 if measure is None else measure(model)
|
|
|
|
|
|
|
|
|
|
|
|
def _calc_pipeline_by_data(pipeline) -> int:
    """Sum the in-memory size, in bytes, of every torch.nn.Module component of a pipeline.

    Components that are None or are not torch modules contribute nothing.
    """
    components = (getattr(pipeline, name) for name in pipeline.components.keys())
    return sum(
        _calc_model_by_data(component)
        for component in components
        if isinstance(component, torch.nn.Module)
    )
|
|
|
|
|
|
|
|
|
|
|
|
def _calc_model_by_data(model) -> int:
|
|
|
|
mem_params = sum([param.nelement()*param.element_size() for param in model.parameters()])
|
|
|
|
mem_bufs = sum([buf.nelement()*buf.element_size() for buf in model.buffers()])
|
|
|
|
mem = mem_params + mem_bufs # in bytes
|
|
|
|
return mem
|
|
|
|
|
|
|
|
|
2023-05-02 02:57:30 +00:00
|
|
|
|
2023-05-08 03:18:17 +00:00
|
|
|
|
2023-05-14 00:06:26 +00:00
|
|
|
class SDModelType(str, Enum):
    """Kinds of models, and of model components, handled by the cache.

    Members are ``str`` subclasses and compare equal to their values
    (e.g. ``SDModelType.UNet == "unet"``). The values are also used as the
    sub-folder names when loading a diffusers pipeline component.
    """
    Diffusers = "diffusers"
    Classifier = "classifier"
    UNet = "unet"
    TextEncoder = "text_encoder"
    Tokenizer = "tokenizer"
    Vae = "vae"
    Scheduler = "scheduler"
    # Used by ClassifierModelInfo for models that ship a preprocessor_config.json.
    FeatureExtractor = "feature_extractor"
|
|
|
|
|
|
|
|
|
|
|
|
class ModelInfoBase:
    """Common base for the per-format model descriptors used by ModelCache.

    Records the identifier the model was requested with, its resolved location
    on disk and its SDModelType. It also provides a helper that turns a
    ``[module, ..., ClassName]`` definition (the form used in diffusers'
    ``model_index.json``) into the Python object it names.
    """

    def __init__(self, repo_id_or_path: str, model_type: SDModelType):
        # The original identifier is kept because from_pretrained() is called with it.
        self.repo_id_or_path = repo_id_or_path  # TODO: or always use the path?
        self.model_path = get_model_path(repo_id_or_path)
        self.model_type = model_type

    def _definition_to_type(self, subtypes: List[str]) -> Type:
        """Resolve a ``[module, ..., name]`` definition to the object it names.

        A definition starting with ``"diffusers"`` or ``"transformers"`` is
        looked up from that (already imported) top-level package. Any other
        definition is looked up inside ``diffusers.pipelines``, using all of
        its elements as the attribute path.

        :raises Exception: if the definition has fewer than two elements.
        """
        if len(subtypes) < 2:
            raise Exception("Invalid subfolder definition!")

        head, *rest = subtypes
        if head in ("diffusers", "transformers"):
            current = sys.modules[head]
            attr_path = rest
        else:
            current = getattr(sys.modules["diffusers"], "pipelines")
            attr_path = subtypes

        for attr_name in attr_path:
            current = getattr(current, attr_name)
        return current
|
2023-05-14 22:09:38 +00:00
|
|
|
|
2023-05-18 00:56:52 +00:00
|
|
|
|
|
|
|
class DiffusersModelInfo(ModelInfoBase):
    """Descriptor for a diffusers pipeline directory or repo.

    ``model_index.json`` is read to discover the pipeline's components. For
    each component, the class and an on-disk size estimate are recorded.
    Individual components, or the whole pipeline, can then be loaded.
    """

    def __init__(self, repo_id_or_path: str, model_type: SDModelType):
        assert model_type == SDModelType.Diffusers
        super().__init__(repo_id_or_path, model_type)

        # component (sub-folder) name -> component class / size in bytes
        self.child_types: Dict[str, Type] = dict()
        self.child_sizes: Dict[str, int] = dict()

        try:
            config_data = DiffusionPipeline.load_config(repo_id_or_path)
        except Exception as e:
            raise Exception("Invalid diffusers model! (model_index.json not found or invalid)") from e

        config_data.pop("_ignore_files", None)

        # Components are the entries whose value is a [library, class_name]
        # list. Entries containing null (diffusers writes [null, null] for a
        # component saved as None) have nothing to load and are skipped.
        child_components = [
            k for k, v in config_data.items()
            if isinstance(v, list) and None not in v
        ]

        for child_name in child_components:
            child_type = self._definition_to_type(config_data[child_name])
            self.child_types[child_name] = child_type
            self.child_sizes[child_name] = calc_model_size_by_fs(repo_id_or_path, subfolder=child_name)

    def get_size(self, child_type: Optional[SDModelType] = None):
        """Return the size in bytes of one component, or of all components when child_type is None."""
        if child_type is None:
            return sum(self.child_sizes.values())
        else:
            return self.child_sizes[child_type]

    def get_model(
        self,
        child_type: Optional[SDModelType] = None,
        torch_dtype: Optional[torch.dtype] = None,
    ):
        """Load a single pipeline component.

        The "fp16" variant is tried first, then the other attempts in order.
        On success, the recorded size is replaced by the size measured from
        the loaded object.

        :param child_type: component to load, as an SDModelType or a plain
            component-folder name from model_index.json.
        :param torch_dtype: dtype forwarded to ``from_pretrained``.
        :returns: the loaded component, or None if the pipeline has no such component.
        :raises Exception: if child_type is None, or if every load attempt fails
            (the last loader error is chained as the cause).
        """
        # return pipeline in different function to pass more arguments
        if child_type is None:
            raise Exception("Child model type can't be null on diffusers model")
        if child_type not in self.child_types:
            return None  # TODO: or raise

        # SDModelType members carry the folder name in .value; plain strings are used as-is.
        subfolder = getattr(child_type, "value", child_type)

        # NOTE(review): "main" looks like a revision name rather than a weights
        # variant; it is kept to preserve the original attempt order.
        last_error = None
        for variant in ["fp16", "main", None]:
            try:
                model = self.child_types[child_type].from_pretrained(
                    self.repo_id_or_path,
                    subfolder=subfolder,
                    cache_dir=global_cache_dir('hub'),
                    torch_dtype=torch_dtype,
                    variant=variant,
                )
                break
            except Exception as e:
                last_error = e
                logger.debug(f"Failed to load {subfolder} with variant {variant}: {e}")
        else:
            raise Exception(
                f"Failed to load {subfolder} from {self.repo_id_or_path}"
            ) from last_error

        # calc more accurate size
        self.child_sizes[child_type] = calc_model_size_by_data(model)
        return model

    def get_pipeline(self, **kwargs):
        """Load the whole pipeline; keyword arguments are forwarded to DiffusionPipeline.from_pretrained."""
        return DiffusionPipeline.from_pretrained(
            self.repo_id_or_path,
            **kwargs,
        )
|
|
|
|
|
|
|
|
|
|
|
|
class EmptyConfigLoader(ConfigMixin):
    """ConfigMixin shim for reading an arbitrary JSON config file.

    Callers choose the file per call with a ``config_name=...`` keyword, e.g.
    ``EmptyConfigLoader.load_config(path, config_name="config.json")``. The
    name is stored as the class attribute ``config_name`` before delegating to
    ``ConfigMixin.load_config``.
    """

    @classmethod
    def load_config(cls, *args, **kwargs):
        # NOTE: this assigns a class attribute, so the value is shared by all callers.
        requested_name = kwargs.pop("config_name")
        cls.config_name = requested_name
        return super().load_config(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
class ClassifierModelInfo(ModelInfoBase):
    """Descriptor for a transformers model directory (tokenizer, text encoder
    and, optionally, a feature extractor).

    ``config.json``, ``tokenizer_config.json`` and, when present,
    ``preprocessor_config.json`` are all read from the model's root directory
    to resolve the component classes.
    """

    # Key for the optional feature-extractor component. A plain string is used
    # because SDModelType members are str-valued; DiffusersModelInfo already
    # relies on str keys and SDModelType lookups being interchangeable.
    _FEATURE_EXTRACTOR_KEY = "feature_extractor"

    def __init__(self, repo_id_or_path: str, model_type: SDModelType):
        assert model_type == SDModelType.Classifier
        super().__init__(repo_id_or_path, model_type)

        # component key -> component class / size in bytes
        self.child_types: Dict[str, Type] = dict()
        self.child_sizes: Dict[str, int] = dict()

        try:
            main_config = EmptyConfigLoader.load_config(repo_id_or_path, config_name="config.json")
        except Exception as e:
            raise Exception("Invalid classifier model! (config.json not found or invalid)") from e

        self._load_tokenizer(main_config)
        self._load_text_encoder(main_config)
        self._load_feature_extractor(main_config)

    def _load_tokenizer(self, main_config: dict):
        """Resolve the tokenizer class from tokenizer_config.json, falling back to the model_type mapping."""
        try:
            tokenizer_config = EmptyConfigLoader.load_config(self.repo_id_or_path, config_name="tokenizer_config.json")
        except Exception as e:
            raise Exception("Invalid classifier model! (Failed to load tokenizer_config.json)") from e

        tokenizer_class_name = tokenizer_config.get("tokenizer_class")
        if not tokenizer_class_name and "model_type" in main_config:
            from transformers.models.auto.tokenization_auto import TOKENIZER_MAPPING_NAMES
            names = TOKENIZER_MAPPING_NAMES.get(main_config["model_type"])
            # transformers stores (slow, fast) class-name pairs here, either of
            # which may be None; take the first one that is set.
            if isinstance(names, (tuple, list)):
                names = next((n for n in names if n), None)
            tokenizer_class_name = names

        if not tokenizer_class_name:
            raise Exception("Invalid classifier model! (Failed to detect tokenizer type)")

        self.child_types[SDModelType.Tokenizer] = self._definition_to_type(["transformers", tokenizer_class_name])
        self.child_sizes[SDModelType.Tokenizer] = 0

    def _load_text_encoder(self, main_config: dict):
        """Resolve the text-encoder class from config.json and estimate its size from the root files."""
        architectures = main_config.get("architectures") or []
        text_encoder_class_name = None
        if len(architectures) > 0:
            text_encoder_class_name = architectures[0]
        elif "model_type" in main_config:
            from transformers.models.auto.modeling_auto import MODEL_FOR_PRETRAINING_MAPPING_NAMES
            text_encoder_class_name = MODEL_FOR_PRETRAINING_MAPPING_NAMES.get(main_config["model_type"])

        if not text_encoder_class_name:
            raise Exception("Invalid classifier model! (Failed to detect text_encoder type)")

        self.child_types[SDModelType.TextEncoder] = self._definition_to_type(["transformers", text_encoder_class_name])
        self.child_sizes[SDModelType.TextEncoder] = calc_model_size_by_fs(self.repo_id_or_path)

    def _load_feature_extractor(self, main_config: dict):
        """Resolve the optional feature-extractor class from preprocessor_config.json."""
        self.child_sizes[self._FEATURE_EXTRACTOR_KEY] = 0
        try:
            feature_extractor_config = EmptyConfigLoader.load_config(self.repo_id_or_path, config_name="preprocessor_config.json")
        except Exception:
            return  # feature extractor not passed with t5

        try:
            feature_extractor_class_name = feature_extractor_config["feature_extractor_type"]
            self.child_types[self._FEATURE_EXTRACTOR_KEY] = self._definition_to_type(["transformers", feature_extractor_class_name])
        except Exception as e:
            raise Exception("Invalid classifier model! (Unknown feature_extrator type)") from e

    def get_size(self, child_type: Optional[SDModelType] = None):
        """Return the size in bytes of one component, or of all components when child_type is None."""
        if child_type is None:
            return sum(self.child_sizes.values())
        else:
            return self.child_sizes[child_type]

    def get_model(
        self,
        child_type: Optional[SDModelType] = None,
        torch_dtype: Optional[torch.dtype] = None,
    ):
        """Load one component from the model root and refresh its recorded size.

        :returns: the loaded component, or None if this model has no such component.
        :raises Exception: if child_type is None.
        """
        if child_type is None:
            raise Exception("Child model type can't be null on classififer model")
        if child_type not in self.child_types:
            return None  # TODO: or raise

        # All configs above are read from the model root, so no subfolder is passed.
        model = self.child_types[child_type].from_pretrained(
            self.repo_id_or_path,
            cache_dir=global_cache_dir('hub'),
            torch_dtype=torch_dtype,
        )
        # calc more accurate size
        self.child_sizes[child_type] = calc_model_size_by_data(model)
        return model
|
2023-05-02 02:57:30 +00:00
|
|
|
|
2023-05-03 16:38:18 +00:00
|
|
|
|
2023-05-18 00:56:52 +00:00
|
|
|
|
|
|
|
class VaeModelInfo(ModelInfoBase):
    """Descriptor for a standalone diffusers VAE directory or repo (no child models)."""

    def __init__(self, repo_id_or_path: str, model_type: SDModelType):
        assert model_type == SDModelType.Vae
        super().__init__(repo_id_or_path, model_type)

        try:
            config = EmptyConfigLoader.load_config(repo_id_or_path, config_name="config.json")
        except Exception as e:
            raise Exception("Invalid vae model! (config.json not found or invalid)") from e

        try:
            # config.json names the class in "_class_name"; AutoencoderKL is the fallback.
            vae_class_name = config.get("_class_name", "AutoencoderKL")
            self.vae_class = self._definition_to_type(["diffusers", vae_class_name])
        except Exception as e:
            raise Exception("Invalid vae model! (Unkown vae type)") from e

        # Kept outside the try above so filesystem errors are not reported as an unknown vae type.
        self.model_size = calc_model_size_by_fs(repo_id_or_path)

    def get_size(self, child_type: Optional[SDModelType] = None):
        """Return the VAE size in bytes; a VAE has no child models."""
        if child_type is not None:
            raise Exception("There is no child models in vae model")
        return self.model_size

    def get_model(
        self,
        child_type: Optional[SDModelType] = None,
        torch_dtype: Optional[torch.dtype] = None,
    ):
        """Load the VAE and refresh the recorded size from the loaded object.

        :raises Exception: if a child_type is given.
        """
        if child_type is not None:
            raise Exception("There is no child models in vae model")

        model = self.vae_class.from_pretrained(
            self.repo_id_or_path,
            cache_dir=global_cache_dir('hub'),
            torch_dtype=torch_dtype,
        )
        # calc more accurate size
        self.model_size = calc_model_size_by_data(model)
        return model
|
|
|
|
|
|
|
|
|
|
|
|
# Top-level model type -> ModelInfoBase subclass that describes such a model.
# ModelCache._get_model_info rejects any model type missing from this table.
MODEL_TYPES: Dict[SDModelType, Type[ModelInfoBase]] = {
    SDModelType.Diffusers: DiffusersModelInfo,
    SDModelType.Classifier: ClassifierModelInfo,
    SDModelType.Vae: VaeModelInfo,
}
|
|
|
|
|
|
|
|
|
|
|
|
# Default maximum size of the model cache, in gigabytes (units of GIG below).
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously.
DEFAULT_MAX_CACHE_SIZE = 6.0

# Number of bytes in one gigabyte as used here: 2**30 (a binary GiB).
GIG = 2 ** 30
|
|
|
|
|
|
|
|
class EmptyScheduler(SchedulerMixin, ConfigMixin):
    """Placeholder scheduler class with no behavior of its own.

    TODO: its intended use is not documented (the original was marked only
    "TODO"), and it is not referenced elsewhere in the visible source.
    """
|
2023-05-02 20:52:27 +00:00
|
|
|
|
2023-05-06 04:44:12 +00:00
|
|
|
class ModelLocker:
    """Forward declaration.

    Placeholder for the module-level name ``ModelLocker``. The working
    implementation is the ``ModelLocker`` class nested inside ``ModelCache``,
    which ``ModelCache.get_model`` returns.
    """
|
|
|
|
|
2023-05-18 00:56:52 +00:00
|
|
|
class ModelCache:
    """Forward declaration.

    Exists so that ``_CacheRecord`` can annotate ``_cache: ModelCache`` when its
    class body is evaluated; the full ``ModelCache`` defined later in the module
    rebinds this name.
    """
|
|
|
|
|
|
|
|
class _CacheRecord:
    """One ModelCache entry: the model object, its size in bytes and a lock count.

    While the lock count is positive, the cache neither offloads nor evicts
    the entry.
    """

    model: Any
    size: int
    _locks: int
    _cache: ModelCache

    def __init__(self, cache, model: Any, size: int):
        self._cache = cache
        self.model = model
        self.size = size
        self._locks = 0

    def lock(self):
        """Take one more lock on this entry."""
        self._locks = self._locks + 1

    def unlock(self):
        """Release one lock; the count must never drop below zero."""
        self._locks = self._locks - 1
        assert self._locks >= 0

    @property
    def locked(self):
        """True while at least one lock is held."""
        return self._locks >= 1

    @property
    def loaded(self):
        """True when the model reports a ``device`` other than the cache's storage device."""
        model = self.model
        if model is None or not hasattr(model, "device"):
            return False
        return model.device != self._cache.storage_device
|
|
|
|
|
|
|
|
|
2023-04-28 04:41:52 +00:00
|
|
|
class ModelCache(object):
    """RAM cache of models that moves them to the execution device on demand.

    Models are loaded through the ModelInfoBase subclasses in MODEL_TYPES and
    kept on ``storage_device``. ``self.stack`` orders cache keys from least to
    most recently requested. When loading a model would push the total size
    over ``max_cache_size`` GB, unlocked entries are evicted oldest first.
    ``get_model`` returns a ``ModelCache.ModelLocker`` context manager which,
    with ``gpu_load=True``, locks the model and moves it to
    ``execution_device`` for the duration of the ``with`` block.
    """

    def __init__(
        self,
        max_cache_size: float=DEFAULT_MAX_CACHE_SIZE,
        execution_device: torch.device=torch.device('cuda'),
        storage_device: torch.device=torch.device('cpu'),
        precision: torch.dtype=torch.float16,
        sequential_offload: bool=False,
        lazy_offloading: bool=True,
        sha_chunksize: int = 16777216,
        logger: types.ModuleType = logger
    ):
        '''
        :param max_cache_size: Maximum size of the RAM cache, in GB [DEFAULT_MAX_CACHE_SIZE]
        :param execution_device: Torch device to load active model into [torch.device('cuda')]
        :param storage_device: Torch device to save inactive model in [torch.device('cpu')]
        :param precision: Precision for loaded models [torch.float16]
        :param lazy_offloading: Keep model in VRAM until another model needs to be loaded
        :param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
        :param sha_chunksize: Chunksize to use when calculating sha256 model hash
        :param logger: Logging module used for cache messages
        '''
        self.models: Dict[str, _CacheRecord] = dict()
        self.model_infos: Dict[str, ModelInfoBase] = dict()
        # cache keys, least recently requested first
        self.stack: List[str] = list()
        self.lazy_offloading = lazy_offloading
        self.sequential_offload: bool=sequential_offload
        self.precision: torch.dtype=precision
        self.current_cache_size: int=0  # bytes
        self.max_cache_size: float=max_cache_size  # GB
        self.execution_device: torch.device=execution_device
        self.storage_device: torch.device=storage_device
        self.sha_chunksize=sha_chunksize
        self.logger = logger

    def get_key(
        self,
        model_path: str,
        model_type: SDModelType,
        revision: Optional[str] = None,
        submodel_type: Optional[SDModelType] = None,
    ):
        """Build the cache key "path:type:revision[:submodel]"; revision defaults to "main"."""
        revision = revision or "main"

        key = f"{model_path}:{model_type}:{revision}"
        if submodel_type:
            key += f":{submodel_type}"
        return key

    def _get_model_info(
        self,
        model_path: str,
        model_type: SDModelType,
        revision: str,
    ):
        """Return the (memoized) ModelInfoBase descriptor for a model.

        :raises Exception: if model_type has no entry in MODEL_TYPES.
        """
        model_info_key = self.get_key(
            model_path=model_path,
            model_type=model_type,
            revision=revision,
            submodel_type=None,
        )

        if model_info_key not in self.model_infos:
            if model_type not in MODEL_TYPES:
                raise Exception(f"Unknown/unsupported model type: {model_type}")

            self.model_infos[model_info_key] = MODEL_TYPES[model_type](
                model_path,
                model_type,
            )

        return self.model_infos[model_info_key]

    def get_model(
        self,
        repo_id_or_path: Union[str, Path],
        model_type: SDModelType = SDModelType.Diffusers,
        submodel: SDModelType = None,
        revision: str = None,
        gpu_load: bool = True,
    ) -> Any:
        """Return a ModelLocker for the requested model, loading it into the cache if needed.

        :param repo_id_or_path: local path or HuggingFace repo id.
        :param model_type: top-level model type (a key of MODEL_TYPES).
        :param submodel: component to load (e.g. SDModelType.UNet), if any.
        :param revision: revision used in the cache key (defaults to "main").
        :param gpu_load: whether entering the locker moves the model to the execution device.
        """
        model_path = get_model_path(repo_id_or_path)
        model_info = self._get_model_info(
            model_path=model_path,
            model_type=model_type,
            revision=revision,
        )

        key = self.get_key(
            model_path=model_path,
            model_type=model_type,
            revision=revision,
            submodel_type=submodel,
        )

        if key not in self.models:
            self.logger.info(f'Loading model {repo_id_or_path}, type {model_type}:{submodel}')

            # this will remove older cached models until
            # there is sufficient room to load the requested model
            self._make_cache_room(model_info.get_size(submodel))

            # clean memory to make MemoryUsage() more accurate
            gc.collect()
            model_obj = model_info.get_model(submodel, torch_dtype=self.precision)
            mem_used = model_info.get_size(submodel)
            if mem_used:
                self.logger.debug(f'CPU RAM used for load: {(mem_used/GIG):.2f} GB')
                self.current_cache_size += mem_used  # increment size of the cache

            self.models[key] = _CacheRecord(self, model_obj, mem_used)

        # Mark the entry as most recently used.
        with suppress(ValueError):
            self.stack.remove(key)
        self.stack.append(key)

        return self.ModelLocker(self, key, self.models[key].model, gpu_load)

    def uncache_model(self, key: str):
        '''Remove corresponding model from the cache and release its size from the accounting'''
        cache_entry = self.models.pop(key, None)
        if cache_entry is not None:
            self.current_cache_size -= cache_entry.size
        with contextlib.suppress(ValueError):
            self.stack.remove(key)

    class ModelLocker(object):
        """Context manager returned by ModelCache.get_model.

        With ``gpu_load=True``, entering locks the cache entry and moves the
        model to the execution device, and exiting releases the lock. Objects
        without a ``to()`` method are returned untouched.
        """

        def __init__(self, cache, key, model, gpu_load):
            self.gpu_load = gpu_load
            self.cache = cache
            self.key = key
            # This will keep a copy of the model in RAM until the locker
            # is garbage collected. Needs testing!
            self.model = model

        def __enter__(self) -> Any:
            if not hasattr(self.model, 'to'):
                return self.model

            cache_entry = self.cache.models[self.key]

            # NOTE that the model has to have the to() method in order for this
            # code to move it into GPU!
            if self.gpu_load:
                cache_entry.lock()
                try:
                    if self.cache.lazy_offloading:
                        self.cache._offload_unlocked_models()

                    if self.model.device != self.cache.execution_device:
                        self.cache.logger.debug(f'Moving {self.key} into {self.cache.execution_device}')
                        with VRAMUsage() as mem:
                            self.model.to(self.cache.execution_device)  # move into GPU
                        self.cache.logger.debug(f'GPU VRAM used for load: {(mem.vram_used/GIG):.2f} GB')
                except Exception:
                    # __exit__ is not called when __enter__ raises, so release the lock here.
                    cache_entry.unlock()
                    raise

                self.cache.logger.debug(f'Locking {self.key} in {self.cache.execution_device}')
                self.cache._print_cuda_stats()

            # in the event that the caller wants the model in RAM, we
            # move it into CPU if it is in GPU and not locked
            elif cache_entry.loaded and not cache_entry.locked:
                self.model.to(self.cache.storage_device)

            return self.model

        def __exit__(self, exc_type, exc_value, traceback):
            if not hasattr(self.model, 'to'):
                return

            # __enter__ only takes a lock when gpu_load is set, so only release one then.
            cache_entry = self.cache.models.get(self.key)
            if self.gpu_load and cache_entry is not None:
                cache_entry.unlock()
            if not self.cache.lazy_offloading:
                self.cache._offload_unlocked_models()
            self.cache._print_cuda_stats()

    def model_hash(
        self,
        repo_id_or_path: Union[str, Path],
        revision: str = "main",
    ) -> str:
        '''
        Given the HF repo id or path to a model on disk, returns a unique
        hash. Works for legacy checkpoint files, HF models on disk, and HF repo IDs
        :param repo_id_or_path: repo_id string or Path to model file/directory on disk.
        :param revision: optional revision string (if fetching a HF repo_id)
        '''
        revision = revision or "main"
        if Path(repo_id_or_path).is_dir():
            return self._local_model_hash(repo_id_or_path)
        else:
            return self._hf_commit_hash(repo_id_or_path, revision)

    def cache_size(self) -> float:
        "Return the current size of the cache, in GB"
        return self.current_cache_size / GIG

    def _has_cuda(self) -> bool:
        """True if the execution device is a CUDA device."""
        return self.execution_device.type == 'cuda'

    def _print_cuda_stats(self):
        """Log allocated VRAM, cache RAM usage and locked/loaded entry counts."""
        vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
        ram = "%4.2fG" % (self.current_cache_size / GIG)

        loaded_models = 0
        locked_models = 0
        for cache_entry in self.models.values():
            if cache_entry.loaded:
                loaded_models += 1
            if cache_entry.locked:
                locked_models += 1

        self.logger.debug(f"Current VRAM/RAM usage: {vram}/{ram}; locked_models/loaded_models = {locked_models}/{loaded_models}")

    def _make_cache_room(self, model_size):
        """Evict unlocked entries, least recently used first, until model_size more bytes fit."""
        # calculate how much memory this model will require
        bytes_needed = model_size
        maximum_size = self.max_cache_size * GIG  # stored in GB, convert to bytes

        if self.current_cache_size + bytes_needed > maximum_size:
            self.logger.debug(f'Max cache size exceeded: {(self.current_cache_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional {(bytes_needed/GIG):.2f} GB')

        pos = 0
        while (
            self.current_cache_size + bytes_needed > maximum_size
            and self.current_cache_size > 0
            and pos < len(self.stack)
        ):
            model_key = self.stack[pos]
            cache_entry = self.models[model_key]
            if cache_entry.locked:
                pos += 1  # in use: leave it and try the next-oldest entry
                continue
            self.logger.debug(f'Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)')
            # uncache_model drops the key from self.stack and subtracts its size
            self.uncache_model(model_key)

        gc.collect()

    def _offload_unlocked_models(self):
        """Move every unlocked model that is away from the storage device back onto it."""
        for key, cache_entry in self.models.items():
            if not cache_entry.locked and cache_entry.loaded:
                self.logger.debug(f'Offloading {key} from {self.execution_device} into {self.storage_device}')
                cache_entry.model.to(self.storage_device)

    def _local_model_hash(self, model_path: Union[str, Path]) -> str:
        """Return the sha256 of a local model's weight files, cached in checksum.sha256.

        The stored checksum is reused unless the model directory's mtime is
        newer than the checksum file's.
        """
        sha = hashlib.sha256()
        path = Path(model_path)

        hashpath = path / "checksum.sha256"
        if hashpath.exists() and path.stat().st_mtime <= hashpath.stat().st_mtime:
            with open(hashpath) as f:
                return f.read()

        self.logger.debug(f'computing hash of model {path.name}')
        # NOTE(review): rglob order is filesystem-dependent, so multi-file
        # models may hash differently across machines; sorting would change
        # existing hashes, so the order is left as-is.
        for file in list(path.rglob("*.ckpt")) \
                + list(path.rglob("*.safetensors")) \
                + list(path.rglob("*.pth")):
            with open(file, "rb") as f:
                while chunk := f.read(self.sha_chunksize):
                    sha.update(chunk)
        digest = sha.hexdigest()
        with open(hashpath, "w") as f:
            f.write(digest)
        return digest

    def _hf_commit_hash(self, repo_id: str, revision: str='main') -> str:
        """Return the commit hash that branch `revision` of a HuggingFace repo points to.

        :raises KeyError: if the repo has no branch with that name.
        """
        api = HfApi()
        info = api.list_repo_refs(
            repo_id=repo_id,
            repo_type='model',
        )
        desired_revisions = [branch for branch in info.branches if branch.name == revision]
        if not desired_revisions:
            raise KeyError(f"Revision '{revision}' not found in {repo_id}")
        return desired_revisions[0].target_commit
|
|
|
|
|
2023-04-28 04:41:52 +00:00
|
|
|
class SilenceWarnings(object):
    """Context manager that temporarily silences transformers/diffusers logging
    (error level only) and Python warnings.

    On exit, the logging verbosities and warning filters in effect when the
    block was entered are restored.
    """

    def __init__(self):
        self.transformers_verbosity = transformers_logging.get_verbosity()
        self.diffusers_verbosity = diffusers_logging.get_verbosity()
        self._warnings_context = None

    def __enter__(self):
        # Re-read here so exit restores the levels in effect on entry,
        # not the ones from when this object was constructed.
        self.transformers_verbosity = transformers_logging.get_verbosity()
        self.diffusers_verbosity = diffusers_logging.get_verbosity()
        transformers_logging.set_verbosity_error()
        diffusers_logging.set_verbosity_error()
        # catch_warnings() snapshots the global warning filters so __exit__
        # can restore them exactly, instead of forcing the 'default' action.
        self._warnings_context = warnings.catch_warnings()
        self._warnings_context.__enter__()
        warnings.simplefilter('ignore')
        return self

    def __exit__(self, type, value, traceback):
        transformers_logging.set_verbosity(self.transformers_verbosity)
        diffusers_logging.set_verbosity(self.diffusers_verbosity)
        if self._warnings_context is not None:
            self._warnings_context.__exit__(type, value, traceback)
            self._warnings_context = None
|
2023-05-07 22:07:28 +00:00
|
|
|
|
2023-05-09 01:47:03 +00:00
|
|
|
class VRAMUsage(object):
    """Context manager that measures the change in ``torch.cuda.memory_allocated()``
    across its block; the difference, in bytes, is stored in ``vram_used``."""

    def __init__(self):
        self.vram = None  # allocation snapshot taken on entry
        self.vram_used = 0  # bytes allocated between entry and exit

    def __enter__(self):
        self.vram = torch.cuda.memory_allocated()
        return self

    def __exit__(self, *args):
        allocated_after = torch.cuda.memory_allocated()
        self.vram_used = allocated_after - self.vram
|