ckpt model conversion now done in ModelManager

Lincoln Stein 2023-05-08 23:39:44 -04:00
parent a108155544
commit 9cb962cad7
5 changed files with 134 additions and 61 deletions

View File

@@ -150,7 +150,7 @@ class Generate:
esrgan=None,
free_gpu_mem: bool = False,
safety_checker: bool = False,
max_loaded_models: int = 2,
max_cache_size: int = 6,
# these are deprecated; if present they override values in the conf file
weights=None,
config=None,
@@ -183,7 +183,7 @@ class Generate:
self.codeformer = codeformer
self.esrgan = esrgan
self.free_gpu_mem = free_gpu_mem
self.max_loaded_models = (max_loaded_models,)
self.max_cache_size = max_cache_size
self.size_matters = True # used to warn once about large image sizes and VRAM
self.txt2mask = None
self.safety_checker = None
@@ -220,7 +220,7 @@ class Generate:
conf,
self.device,
torch_dtype(self.device),
max_loaded_models=max_loaded_models,
max_cache_size=max_cache_size,
sequential_offload=self.free_gpu_mem,
# embedding_path=Path(self.embedding_path),
)
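Taken together, the three hunks above rename the cache knob and change its meaning: max_loaded_models was a count of resident models, while max_cache_size is a size budget in gigabytes that Generate now passes straight through to the model cache. A minimal usage sketch of the changed keyword; everything other than max_cache_size is assumed to follow the existing Generate signature shown above:

    # Sketch only: the remaining constructor arguments are assumed unchanged.
    gen = Generate(
        esrgan=None,
        free_gpu_mem=False,
        safety_checker=False,
        max_cache_size=6,   # GB of model cache, replacing max_loaded_models=2
    )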

View File

@@ -94,6 +94,8 @@ def global_set_root(root_dir: Union[str, Path]):
Globals.root = root_dir
def global_resolve_path(path: Union[str,Path]):
if path is None:
return None
return Path(Globals.root,path).resolve()
def global_cache_dir(subdir: Union[str, Path] = "") -> Path:
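global_resolve_path() is new in this commit and is leaned on heavily by the ModelManager changes below: it passes None through unchanged and resolves everything else against Globals.root. A short usage sketch, with a hypothetical runtime root:

    from invokeai.backend.globals import global_set_root, global_resolve_path

    global_set_root("/home/user/invokeai")         # hypothetical runtime root
    global_resolve_path(None)                      # -> None
    global_resolve_path("configs/models.yaml")     # -> /home/user/invokeai/configs/models.yaml
    global_resolve_path("/tmp/foo.ckpt")           # -> /tmp/foo.ckpt (absolute paths override the root in Path())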

View File

@@ -361,9 +361,10 @@ class ModelCache(object):
)->ModelStatus:
key = self._model_key(
repo_id_or_path,
model_type.value,
revision,
subfolder)
subfolder,
model_type.value,
)
if key not in self.models:
return ModelStatus.not_loaded
if key in self.loaded_models:
@@ -384,9 +385,7 @@ class ModelCache(object):
:param revision: optional revision string (if fetching a HF repo_id)
'''
revision = revision or "main"
if self.is_legacy_ckpt(repo_id_or_path):
return self._legacy_model_hash(repo_id_or_path)
elif Path(repo_id_or_path).is_dir():
if Path(repo_id_or_path).is_dir():
return self._local_model_hash(repo_id_or_path)
else:
return self._hf_commit_hash(repo_id_or_path,revision)
@@ -395,15 +394,6 @@ class ModelCache(object):
"Return the current size of the cache, in GB"
return self.current_cache_size / GIG
@classmethod
def is_legacy_ckpt(cls, repo_id_or_path: Union[str,Path])->bool:
'''
Return true if the indicated path is a legacy checkpoint
:param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
'''
path = Path(repo_id_or_path)
return path.suffix in [".ckpt",".safetensors",".pt"]
@classmethod
def scan_model(cls, model_name, checkpoint):
"""
@@ -482,16 +472,12 @@ class ModelCache(object):
'''
# silence transformer and diffuser warnings
with SilenceWarnings():
# !!! NOTE: conversion should not happen here, but in ModelManager
if self.is_legacy_ckpt(repo_id_or_path):
model = self._load_ckpt_from_storage(repo_id_or_path, legacy_info)
else:
model = self._load_diffusers_from_storage(
repo_id_or_path,
subfolder,
revision,
model_class,
)
model = self._load_diffusers_from_storage(
repo_id_or_path,
subfolder,
revision,
model_class,
)
if self.sequential_offload and isinstance(model,StableDiffusionGeneratorPipeline):
model.enable_offload_submodels(self.execution_device)
return model
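With these hunks, ModelCache no longer special-cases legacy checkpoints: the is_legacy_ckpt() helper and the checkpoint branch of the storage loader are removed, so the cache only ever loads diffusers folders or HF repos. Detecting a legacy .ckpt/.pt/.safetensors file is now the caller's job (ModelManager converts it first). A minimal sketch of that suffix test, mirroring the removed helper; this is an assumption about how a caller might use it, not code from the commit:

    from pathlib import Path

    # Mirrors the removed ModelCache.is_legacy_ckpt(): callers run this check
    # (and a conversion) before handing the model to the cache.
    def looks_like_legacy_ckpt(repo_id_or_path) -> bool:
        return Path(repo_id_or_path).suffix in (".ckpt", ".safetensors", ".pt")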

View File

@@ -143,7 +143,7 @@ from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from invokeai.backend.globals import Globals, global_cache_dir, global_resolve_path
from .model_cache import ModelCache, ModelLocker, SDModelType, ModelStatus, LegacyInfo
from .model_cache import ModelCache, ModelLocker, SDModelType, ModelStatus, SilenceWarnings
from ..util import CUDA_DEVICE
@@ -225,12 +225,16 @@ class ModelManager(object):
self.cache_keys = dict()
self.logger = logger
def valid_model(self, model_name: str) -> bool:
def valid_model(self, model_name: str, model_type: SDModelType=SDModelType.diffusers) -> bool:
"""
Given a model name, returns True if it is a valid
identifier.
"""
return model_name in self.config
try:
self._disambiguate_name(model_name, model_type)
return True
except InvalidModelError:
return False
def get_model(self,
model_name: str,
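valid_model() now accepts an optional model_type and answers by attempting the same name disambiguation that get_model() uses, so both bare names and name/type keys validate without raising. A hedged usage sketch; the model names below are placeholders and `mgr` is assumed to be a ModelManager built from an existing models.yaml:

    mgr.valid_model("stable-diffusion-1.5")                       # True if the name (or name/diffusers) is configured
    mgr.valid_model("sd-vae-ft-mse", model_type=SDModelType.vae)  # check a non-diffusers namespace
    mgr.valid_model("no-such-model")                              # False (InvalidModelError is caught internally)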
@@ -294,17 +298,17 @@ class ModelManager(object):
model_parts = dict([(x.name,x) for x in SDModelType])
legacy = None
if format=='ckpt':
location = global_resolve_path(mconfig.weights)
legacy = LegacyInfo(
config_file = global_resolve_path(mconfig.config),
)
if mconfig.get('vae'):
legacy.vae_file = global_resolve_path(mconfig.vae)
elif format=='diffusers':
location = mconfig.get('repo_id') or mconfig.get('path')
if format == 'diffusers':
# intercept stanzas that point to checkpoint weights and replace them
# with the equivalent diffusers model
if 'weights' in mconfig:
location = self.convert_ckpt_and_cache(mconfig)
else:
location = global_resolve_path(mconfig.get('path')) or mconfig.get('repo_id')
elif format in model_parts:
location = mconfig.get('repo_id') or mconfig.get('path') or mconfig.get('weights')
location = global_resolve_path(mconfig.get('path')) \
or mconfig.get('repo_id') \
or global_resolve_path(mconfig.get('weights'))
else:
raise InvalidModelError(
f'"{model_key}" has an unknown format {format}'
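This is the core of the commit: a 'diffusers' stanza that still points at legacy 'weights' is intercepted and converted via convert_ckpt_and_cache() before loading, instead of being handled as a separate 'ckpt' format by the cache. A hypothetical stanza, written here with OmegaConf for illustration; in practice it lives in models.yaml, and only the key names are taken from the fields read above, the values are placeholders:

    from omegaconf import OmegaConf

    mconfig = OmegaConf.create({
        "format": "diffusers",
        # Presence of 'weights' triggers convert_ckpt_and_cache(), which writes a
        # diffusers copy under models/<converted_ckpts_dir>/<weights stem> and
        # returns that path as the load location.
        "weights": "models/ldm/stable-diffusion-v1/my-finetune.ckpt",   # hypothetical
        "config": "configs/stable-diffusion/v1-inference.yaml",         # hypothetical
        "description": "legacy checkpoint, auto-converted on first load",
    })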
@@ -531,7 +535,7 @@ class ModelManager(object):
else:
assert "weights" in model_attributes and "description" in model_attributes
model_key = f'{model_name}/{format}'
model_key = f'{model_name}/{model_attributes["format"]}'
assert (
clobber or model_key not in omega
@@ -776,7 +780,7 @@ class ModelManager(object):
# another round of heuristics to guess the correct config file.
checkpoint = None
if model_path.suffix in [".ckpt", ".pt"]:
self.scan_model(model_path, model_path)
self.cache.scan_model(model_path, model_path)
checkpoint = torch.load(model_path)
else:
checkpoint = safetensors.torch.load_file(model_path)
@@ -840,19 +844,86 @@ class ModelManager(object):
diffuser_path = Path(
Globals.root, "models", Globals.converted_ckpts_dir, model_path.stem
)
model_name = self.convert_and_import(
model_path,
diffusers_path=diffuser_path,
vae=vae,
vae_path=str(vae_path),
model_name=model_name,
model_description=description,
original_config_file=model_config_file,
commit_to_conf=commit_to_conf,
scan_needed=False,
)
with SilenceWarnings():
model_name = self.convert_and_import(
model_path,
diffusers_path=diffuser_path,
vae=vae,
vae_path=str(vae_path),
model_name=model_name,
model_description=description,
original_config_file=model_config_file,
commit_to_conf=commit_to_conf,
scan_needed=False,
)
return model_name
def convert_ckpt_and_cache(self, mconfig:DictConfig)->Path:
"""
Convert the checkpoint model indicated in mconfig into a
diffusers, cache it to disk, and return Path to converted
file. If already on disk then just returns Path.
"""
weights = global_resolve_path(mconfig.weights)
config_file = global_resolve_path(mconfig.config)
diffusers_path = global_resolve_path(Path('models',Globals.converted_ckpts_dir)) / weights.stem
# return cached version if it exists
if diffusers_path.exists():
return diffusers_path
vae_ckpt_path, vae_model = self._get_vae_for_conversion(weights, mconfig)
# to avoid circular import errors
from .convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
with SilenceWarnings():
convert_ckpt_to_diffusers(
weights,
diffusers_path,
extract_ema=True,
original_config_file=config_file,
vae=vae_model,
vae_path=str(global_resolve_path(vae_ckpt_path)),
scan_needed=True,
)
return diffusers_path
def _get_vae_for_conversion(self,
weights: Path,
mconfig: DictConfig
)->tuple(Path,SDModelType.vae):
# VAE handling is convoluted
# 1. If there is a .vae.ckpt file sharing same stem as weights, then use
# it as the vae_path passed to convert
vae_ckpt_path = None
vae_diffusers_location = None
vae_model = None
for suffix in ["pt", "ckpt", "safetensors"]:
if (weights.with_suffix(f".vae.{suffix}")).exists():
vae_ckpt_path = weights.with_suffix(f".vae.{suffix}")
self.logger.debug(f"Using VAE file {vae_ckpt_path.name}")
if vae_ckpt_path:
return (vae_ckpt_path, None)
# 2. If mconfig has a vae weights path, then we use that as vae_path
vae_config = mconfig.get('vae')
if vae_config and isinstance(vae_config,str):
vae_ckpt_path = vae_config
return (vae_ckpt_path, None)
# 3. If mconfig has a vae dict, then we use it as the diffusers-style vae
if vae_config and isinstance(vae_config,DictConfig):
vae_diffusers_location = global_resolve_path(vae_config.get('path')) or vae_config.get('repo_id')
# 4. Otherwise, we use stabilityai/sd-vae-ft-mse "because it works"
else:
vae_diffusers_location = "stabilityai/sd-vae-ft-mse"
if vae_diffusers_location:
vae_model = self.cache.get_model(vae_diffusers_location, SDModelType.vae).model
return (None, vae_model)
return (None, None)
def convert_and_import(
self,
ckpt_path: Path,
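_get_vae_for_conversion() applies a fixed precedence: a sidecar .vae.* file next to the weights wins, then a string 'vae' entry in the stanza, then a diffusers-style 'vae' sub-stanza, and finally the stabilityai/sd-vae-ft-mse fallback. A small sketch of the sidecar lookup in step 1; the weights path is hypothetical:

    from pathlib import Path

    weights = Path("models/ldm/stable-diffusion-v1/my-finetune.safetensors")   # hypothetical
    # Candidates probed by step 1:
    candidates = [weights.with_suffix(f".vae.{sfx}") for sfx in ("pt", "ckpt", "safetensors")]
    # -> my-finetune.vae.pt, my-finetune.vae.ckpt, my-finetune.vae.safetensors
    vae_ckpt_path = None
    for p in candidates:
        if p.exists():
            vae_ckpt_path = p   # as in the diff, the last existing candidate wins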
@@ -895,7 +966,8 @@ class ModelManager(object):
# will be built into the model rather than tacked on afterward via the config file
vae_model = None
if vae:
vae_model = self._load_vae(vae)
vae_location = global_resolve_path(vae.get('path')) or vae.get('repo_id')
vae_model = self.cache.get_model(vae_location,SDModelType.vae).model
vae_path = None
convert_ckpt_to_diffusers(
ckpt_path,
@@ -982,9 +1054,9 @@ class ModelManager(object):
def _disambiguate_name(self, model_name:str, model_type:SDModelType)->str:
model_type = model_type or SDModelType.diffusers
full_name = f"{model_name}/{model_type.name}"
if self.valid_model(full_name):
if full_name in self.config:
return full_name
if self.valid_model(model_name):
if model_name in self.config:
return model_name
raise InvalidModelError(
f'Neither "{model_name}" nor "{full_name}" are known model names. Please check your models.yaml file'
@@ -1014,3 +1086,20 @@ class ModelManager(object):
return path
return Path(Globals.root, path).resolve()
# This is not the same as global_resolve_path(), which prepends
# Globals.root.
def _resolve_path(
self, source: Union[str, Path], dest_directory: str
) -> Optional[Path]:
resolved_path = None
if str(source).startswith(("http:", "https:", "ftp:")):
dest_directory = Path(dest_directory)
if not dest_directory.is_absolute():
dest_directory = Globals.root / dest_directory
dest_directory.mkdir(parents=True, exist_ok=True)
resolved_path = download_with_resume(str(source), dest_directory)
else:
if not os.path.isabs(source):
source = os.path.join(Globals.root, source)
resolved_path = Path(source)
return resolved_path
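The new comment distinguishes the two resolvers: global_resolve_path() only prepends Globals.root, while _resolve_path() also downloads http/https/ftp sources (with resume) into a destination directory under the root. A usage sketch with hypothetical paths and URL, assuming `mgr` is a ModelManager instance:

    local = mgr._resolve_path("models/ldm/stable-diffusion-v1/model.ckpt", "models/ldm")
    # -> Globals.root/models/ldm/stable-diffusion-v1/model.ckpt, no download

    remote = mgr._resolve_path(
        "https://example.com/weights/my-finetune.safetensors",   # hypothetical URL
        "models/ldm/stable-diffusion-v1",
    )
    # -> path of the file download_with_resume() saved under
    #    Globals.root/models/ldm/stable-diffusion-v1/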

View File

@@ -54,10 +54,6 @@ def main():
"--weights argument has been deprecated. Please edit ./configs/models.yaml, and select the weights using --model instead."
)
sys.exit(-1)
if args.max_loaded_models is not None:
if args.max_loaded_models <= 0:
print("--max_loaded_models must be >= 1; using 1")
args.max_loaded_models = 1
# alert - setting a few globals here
Globals.try_patchmatch = args.patchmatch
@@ -136,7 +132,7 @@ def main():
esrgan=esrgan,
free_gpu_mem=opt.free_gpu_mem,
safety_checker=opt.safety_checker,
max_loaded_models=opt.max_loaded_models,
max_cache_size=opt.max_cache_size,
)
except (FileNotFoundError, TypeError, AssertionError) as e:
report_model_error(opt, e)