Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
ckpt model conversion now done in ModelCache
This commit is contained in:
parent a108155544 · commit 9cb962cad7
@@ -150,7 +150,7 @@ class Generate:
         esrgan=None,
         free_gpu_mem: bool = False,
         safety_checker: bool = False,
-        max_loaded_models: int = 2,
+        max_cache_size: int = 6,
         # these are deprecated; if present they override values in the conf file
         weights=None,
         config=None,
@@ -183,7 +183,7 @@ class Generate:
         self.codeformer = codeformer
         self.esrgan = esrgan
         self.free_gpu_mem = free_gpu_mem
-        self.max_loaded_models = (max_loaded_models,)
+        self.max_cache_size = max_cache_size
         self.size_matters = True  # used to warn once about large image sizes and VRAM
         self.txt2mask = None
         self.safety_checker = None
@@ -220,7 +220,7 @@ class Generate:
             conf,
             self.device,
             torch_dtype(self.device),
-            max_loaded_models=max_loaded_models,
+            max_cache_size=max_cache_size,
             sequential_offload=self.free_gpu_mem,
             # embedding_path=Path(self.embedding_path),
         )
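Taken together, the three Generate hunks above replace the old count-based cap (max_loaded_models, the number of models kept resident) with a size-based cap (max_cache_size, apparently in GB, matching the "size of the cache, in GB" property in the ModelCache hunks below). A minimal construction sketch, assuming Generate still lives in ldm.generate; the model name and option values are illustrative only:

    from ldm.generate import Generate  # assumed import path

    # Cap the model cache at roughly 6 GB instead of capping the number of loaded models.
    gr = Generate(
        model="stable-diffusion-1.5",  # hypothetical entry from models.yaml
        max_cache_size=6,
        free_gpu_mem=False,
        safety_checker=False,
    )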
@@ -94,6 +94,8 @@ def global_set_root(root_dir: Union[str, Path]):
     Globals.root = root_dir

 def global_resolve_path(path: Union[str,Path]):
+    if path is None:
+        return None
     return Path(Globals.root,path).resolve()

 def global_cache_dir(subdir: Union[str, Path] = "") -> Path:
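The two added lines give global_resolve_path() a None pass-through, so optional models.yaml entries (a missing vae or config key, for instance) can be resolved without a guard at every call site. A minimal usage sketch, assuming Globals.root has already been pointed at the InvokeAI root:

    from invokeai.backend.globals import Globals, global_resolve_path

    Globals.root = "/opt/invokeai"                        # hypothetical root directory
    print(global_resolve_path("models/ldm/sd-1.5.ckpt"))  # -> /opt/invokeai/models/ldm/sd-1.5.ckpt
    print(global_resolve_path(None))                      # -> None (previously raised TypeError)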
@@ -361,9 +361,10 @@ class ModelCache(object):
               )->ModelStatus:
         key = self._model_key(
             repo_id_or_path,
-            model_type.value,
             revision,
-            subfolder)
+            subfolder,
+            model_type.value,
+        )
         if key not in self.models:
             return ModelStatus.not_loaded
         if key in self.loaded_models:
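The status() hunk reorders the fields passed to _model_key() so that model_type comes last. If the key is assembled positionally, every call site has to agree on that order or status() will report not_loaded for a model the cache is actually holding. _model_key() itself is not part of this diff; a purely hypothetical illustration of such a positional key builder:

    # Hypothetical sketch only -- the real _model_key() is not shown in this commit.
    def _model_key(repo_id_or_path, revision, subfolder, model_type) -> str:
        return ":".join(str(part) for part in (repo_id_or_path, revision, subfolder, model_type))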
@@ -384,9 +385,7 @@ class ModelCache(object):
         :param revision: optional revision string (if fetching a HF repo_id)
         '''
         revision = revision or "main"
-        if self.is_legacy_ckpt(repo_id_or_path):
-            return self._legacy_model_hash(repo_id_or_path)
-        elif Path(repo_id_or_path).is_dir():
+        if Path(repo_id_or_path).is_dir():
             return self._local_model_hash(repo_id_or_path)
         else:
             return self._hf_commit_hash(repo_id_or_path,revision)
@@ -395,15 +394,6 @@ class ModelCache(object):
         "Return the current size of the cache, in GB"
         return self.current_cache_size / GIG

-    @classmethod
-    def is_legacy_ckpt(cls, repo_id_or_path: Union[str,Path])->bool:
-        '''
-        Return true if the indicated path is a legacy checkpoint
-        :param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
-        '''
-        path = Path(repo_id_or_path)
-        return path.suffix in [".ckpt",".safetensors",".pt"]
-
     @classmethod
     def scan_model(cls, model_name, checkpoint):
         """
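The removed is_legacy_ckpt() classmethod was a plain suffix test, and it is no longer needed inside ModelCache now that the legacy-checkpoint branch has been dropped from the cache's load path (next hunk). For reference, a standalone equivalent of the removed check; the function name here is illustrative, not part of the new API:

    from pathlib import Path

    def looks_like_legacy_checkpoint(repo_id_or_path) -> bool:
        # Mirrors the removed classmethod: legacy checkpoints are identified purely by suffix.
        return Path(repo_id_or_path).suffix in [".ckpt", ".safetensors", ".pt"]

    assert looks_like_legacy_checkpoint("v1-5-pruned-emaonly.safetensors")
    assert not looks_like_legacy_checkpoint("runwayml/stable-diffusion-v1-5")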
@@ -482,16 +472,12 @@ class ModelCache(object):
         '''
         # silence transformer and diffuser warnings
         with SilenceWarnings():
-            # !!! NOTE: conversion should not happen here, but in ModelManager
-            if self.is_legacy_ckpt(repo_id_or_path):
-                model = self._load_ckpt_from_storage(repo_id_or_path, legacy_info)
-            else:
-                model = self._load_diffusers_from_storage(
-                    repo_id_or_path,
-                    subfolder,
-                    revision,
-                    model_class,
-                )
+            model = self._load_diffusers_from_storage(
+                repo_id_or_path,
+                subfolder,
+                revision,
+                model_class,
+            )
         if self.sequential_offload and isinstance(model,StableDiffusionGeneratorPipeline):
             model.enable_offload_submodels(self.execution_device)
         return model
@@ -143,7 +143,7 @@ from omegaconf import OmegaConf
 from omegaconf.dictconfig import DictConfig

 from invokeai.backend.globals import Globals, global_cache_dir, global_resolve_path
-from .model_cache import ModelCache, ModelLocker, SDModelType, ModelStatus, LegacyInfo
+from .model_cache import ModelCache, ModelLocker, SDModelType, ModelStatus, SilenceWarnings

 from ..util import CUDA_DEVICE

@@ -225,12 +225,16 @@ class ModelManager(object):
         self.cache_keys = dict()
         self.logger = logger

-    def valid_model(self, model_name: str) -> bool:
+    def valid_model(self, model_name: str, model_type: SDModelType=SDModelType.diffusers) -> bool:
         """
         Given a model name, returns True if it is a valid
         identifier.
         """
-        return model_name in self.config
+        try:
+            self._disambiguate_name(model_name, model_type)
+            return True
+        except InvalidModelError:
+            return False

     def get_model(self,
                   model_name: str,
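valid_model() now takes an optional model type and defers to _disambiguate_name() (changed near the end of this diff) rather than doing a bare membership test, so both plain names and fully qualified "name/type" keys validate. A usage sketch against an already-initialized ModelManager; the model names are hypothetical and SDModelType comes from the model_cache import shown above:

    manager.valid_model("stable-diffusion-1.5")                   # resolves "stable-diffusion-1.5/diffusers"
    manager.valid_model("stable-diffusion-1.5", SDModelType.vae)  # resolves "stable-diffusion-1.5/vae"
    manager.valid_model("stable-diffusion-1.5/diffusers")         # already-qualified keys still work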
@@ -294,17 +298,17 @@ class ModelManager(object):
         model_parts = dict([(x.name,x) for x in SDModelType])
         legacy = None

-        if format=='ckpt':
-            location = global_resolve_path(mconfig.weights)
-            legacy = LegacyInfo(
-                config_file = global_resolve_path(mconfig.config),
-            )
-            if mconfig.get('vae'):
-                legacy.vae_file = global_resolve_path(mconfig.vae)
-        elif format=='diffusers':
-            location = mconfig.get('repo_id') or mconfig.get('path')
+        if format == 'diffusers':
+            # intercept stanzas that point to checkpoint weights and replace them
+            # with the equivalent diffusers model
+            if 'weights' in mconfig:
+                location = self.convert_ckpt_and_cache(mconfig)
+            else:
+                location = global_resolve_path(mconfig.get('path')) or mconfig.get('repo_id')
         elif format in model_parts:
-            location = mconfig.get('repo_id') or mconfig.get('path') or mconfig.get('weights')
+            location = global_resolve_path(mconfig.get('path')) \
+                        or mconfig.get('repo_id') \
+                        or global_resolve_path(mconfig.get('weights'))
         else:
             raise InvalidModelError(
                 f'"{model_key}" has an unknown format {format}'
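With this routing change, a stanza whose format is "diffusers" but which still carries a weights entry is converted on first use through convert_ckpt_and_cache() (added later in this diff), while a native diffusers stanza resolves directly to a local path or repo_id. A sketch of the two stanza shapes as OmegaConf objects; the names and paths are made up for illustration:

    from omegaconf import OmegaConf

    # Legacy-weights stanza: 'weights' is present, so get_model() routes it
    # through convert_ckpt_and_cache() and uses the converted diffusers folder.
    legacy_stanza = OmegaConf.create({
        "format": "diffusers",
        "weights": "models/ldm/stable-diffusion-v1/v1-5-pruned.ckpt",
        "config": "configs/stable-diffusion/v1-inference.yaml",
    })

    # Native diffusers stanza: resolved straight to a repo_id (or a local 'path').
    diffusers_stanza = OmegaConf.create({
        "format": "diffusers",
        "repo_id": "runwayml/stable-diffusion-v1-5",
    })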
@@ -531,7 +535,7 @@ class ModelManager(object):
         else:
             assert "weights" in model_attributes and "description" in model_attributes

-        model_key = f'{model_name}/{format}'
+        model_key = f'{model_name}/{model_attributes["format"]}'

         assert (
             clobber or model_key not in omega
@@ -776,7 +780,7 @@ class ModelManager(object):
         # another round of heuristics to guess the correct config file.
         checkpoint = None
         if model_path.suffix in [".ckpt", ".pt"]:
-            self.scan_model(model_path, model_path)
+            self.cache.scan_model(model_path, model_path)
             checkpoint = torch.load(model_path)
         else:
             checkpoint = safetensors.torch.load_file(model_path)
@@ -840,19 +844,86 @@ class ModelManager(object):
         diffuser_path = Path(
             Globals.root, "models", Globals.converted_ckpts_dir, model_path.stem
         )
-        model_name = self.convert_and_import(
-            model_path,
-            diffusers_path=diffuser_path,
-            vae=vae,
-            vae_path=str(vae_path),
-            model_name=model_name,
-            model_description=description,
-            original_config_file=model_config_file,
-            commit_to_conf=commit_to_conf,
-            scan_needed=False,
-        )
+        with SilenceWarnings():
+            model_name = self.convert_and_import(
+                model_path,
+                diffusers_path=diffuser_path,
+                vae=vae,
+                vae_path=str(vae_path),
+                model_name=model_name,
+                model_description=description,
+                original_config_file=model_config_file,
+                commit_to_conf=commit_to_conf,
+                scan_needed=False,
+            )
         return model_name

+    def convert_ckpt_and_cache(self, mconfig:DictConfig)->Path:
+        """
+        Convert the checkpoint model indicated in mconfig into a
+        diffusers, cache it to disk, and return Path to converted
+        file. If already on disk then just returns Path.
+        """
+        weights = global_resolve_path(mconfig.weights)
+        config_file = global_resolve_path(mconfig.config)
+        diffusers_path = global_resolve_path(Path('models',Globals.converted_ckpts_dir)) / weights.stem
+
+        # return cached version if it exists
+        if diffusers_path.exists():
+            return diffusers_path
+
+        vae_ckpt_path, vae_model = self._get_vae_for_conversion(weights, mconfig)
+        # to avoid circular import errors
+        from .convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
+        with SilenceWarnings():
+            convert_ckpt_to_diffusers(
+                weights,
+                diffusers_path,
+                extract_ema=True,
+                original_config_file=config_file,
+                vae=vae_model,
+                vae_path=str(global_resolve_path(vae_ckpt_path)),
+                scan_needed=True,
+            )
+        return diffusers_path
+
+    def _get_vae_for_conversion(self,
+                                weights: Path,
+                                mconfig: DictConfig
+                                )->tuple(Path,SDModelType.vae):
+        # VAE handling is convoluted
+        # 1. If there is a .vae.ckpt file sharing same stem as weights, then use
+        # it as the vae_path passed to convert
+        vae_ckpt_path = None
+        vae_diffusers_location = None
+        vae_model = None
+        for suffix in ["pt", "ckpt", "safetensors"]:
+            if (weights.with_suffix(f".vae.{suffix}")).exists():
+                vae_ckpt_path = weights.with_suffix(f".vae.{suffix}")
+                self.logger.debug(f"Using VAE file {vae_ckpt_path.name}")
+        if vae_ckpt_path:
+            return (vae_ckpt_path, None)
+
+        # 2. If mconfig has a vae weights path, then we use that as vae_path
+        vae_config = mconfig.get('vae')
+        if vae_config and isinstance(vae_config,str):
+            vae_ckpt_path = vae_config
+            return (vae_ckpt_path, None)
+
+        # 3. If mconfig has a vae dict, then we use it as the diffusers-style vae
+        if vae_config and isinstance(vae_config,DictConfig):
+            vae_diffusers_location = global_resolve_path(vae_config.get('path')) or vae_config.get('repo_id')

+        # 4. Otherwise, we use stabilityai/sd-vae-ft-mse "because it works"
+        else:
+            vae_diffusers_location = "stabilityai/sd-vae-ft-mse"
+
+        if vae_diffusers_location:
+            vae_model = self.cache.get_model(vae_diffusers_location, SDModelType.vae).model
+            return (None, vae_model)
+
+        return (None, None)
+
     def convert_and_import(
         self,
         ckpt_path: Path,
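convert_ckpt_and_cache() is idempotent on disk: the converted model lands under models/<converted_ckpts_dir>/<weights stem>, and an existing folder short-circuits the conversion. VAE selection follows the four-step precedence in _get_vae_for_conversion(): a sibling .vae.{pt,ckpt,safetensors} file wins, then a string vae: entry in the stanza, then a diffusers-style vae: dict, with stabilityai/sd-vae-ft-mse as the fallback. (One caveat: the tuple(Path,SDModelType.vae) return annotation is evaluated when the method is defined and would raise a TypeError unless the module enables postponed annotations; a typing.Tuple[...] form would be the conventional spelling.) A compact, hypothetical restatement of the precedence, not the project's API:

    from pathlib import Path

    def pick_vae(weights: Path, vae_setting):
        """Hypothetical restatement of _get_vae_for_conversion()'s precedence rules."""
        # 1. A sibling <stem>.vae.{pt,ckpt,safetensors} file wins outright.
        for suffix in ("pt", "ckpt", "safetensors"):
            candidate = weights.with_suffix(f".vae.{suffix}")
            if candidate.exists():
                return ("ckpt_path", candidate)
        # 2. A plain string in the stanza is treated as a checkpoint VAE path.
        if isinstance(vae_setting, str) and vae_setting:
            return ("ckpt_path", vae_setting)
        # 3. A dict selects a diffusers-style VAE by local path or repo_id.
        if isinstance(vae_setting, dict) and vae_setting:
            return ("diffusers", vae_setting.get("path") or vae_setting.get("repo_id"))
        # 4. Otherwise fall back to the stock VAE.
        return ("diffusers", "stabilityai/sd-vae-ft-mse")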
@@ -895,7 +966,8 @@ class ModelManager(object):
         # will be built into the model rather than tacked on afterward via the config file
         vae_model = None
         if vae:
-            vae_model = self._load_vae(vae)
+            vae_location = global_resolve_path(vae.get('path')) or vae.get('repo_id')
+            vae_model = self.cache.get_model(vae_location,SDModelType.vae).model
         vae_path = None
         convert_ckpt_to_diffusers(
             ckpt_path,
@@ -982,9 +1054,9 @@ class ModelManager(object):
     def _disambiguate_name(self, model_name:str, model_type:SDModelType)->str:
         model_type = model_type or SDModelType.diffusers
         full_name = f"{model_name}/{model_type.name}"
-        if self.valid_model(full_name):
+        if full_name in self.config:
             return full_name
-        if self.valid_model(model_name):
+        if model_name in self.config:
             return model_name
         raise InvalidModelError(
             f'Neither "{model_name}" nor "{full_name}" are known model names. Please check your models.yaml file'
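Because valid_model() now calls _disambiguate_name(), the helper switches from valid_model() checks to direct self.config membership tests; keeping the old calls would have made the two methods recurse into each other. The effective lookup order, sketched with hypothetical config keys:

    # Hypothetical keys, as they would appear in models.yaml once qualified by type.
    config_keys = {"sd-1.5/diffusers", "sd-1.5/vae"}

    def disambiguate(model_name: str, type_name: str = "diffusers") -> str:
        full_name = f"{model_name}/{type_name}"
        if full_name in config_keys:   # 1. try the qualified "name/type" key
            return full_name
        if model_name in config_keys:  # 2. fall back to the name exactly as given
            return model_name
        raise KeyError(f'Neither "{model_name}" nor "{full_name}" are known model names')

    print(disambiguate("sd-1.5"))      # -> "sd-1.5/diffusers"
    print(disambiguate("sd-1.5/vae"))  # -> "sd-1.5/vae"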
@@ -1014,3 +1086,20 @@ class ModelManager(object):
             return path
         return Path(Globals.root, path).resolve()

+    # This is not the same as global_resolve_path(), which prepends
+    # Globals.root.
+    def _resolve_path(
+        self, source: Union[str, Path], dest_directory: str
+    ) -> Optional[Path]:
+        resolved_path = None
+        if str(source).startswith(("http:", "https:", "ftp:")):
+            dest_directory = Path(dest_directory)
+            if not dest_directory.is_absolute():
+                dest_directory = Globals.root / dest_directory
+            dest_directory.mkdir(parents=True, exist_ok=True)
+            resolved_path = download_with_resume(str(source), dest_directory)
+        else:
+            if not os.path.isabs(source):
+                source = os.path.join(Globals.root, source)
+            resolved_path = Path(source)
+        return resolved_path
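The new _resolve_path() complements global_resolve_path(): URL sources are downloaded (with resume support) into dest_directory beneath the InvokeAI root, while local relative paths are simply anchored at Globals.root and returned untouched. A usage sketch against an initialized ModelManager; the URL and paths are illustrative:

    local = manager._resolve_path("models/ldm/custom-model.safetensors", "models/ldm")
    # -> Path("<Globals.root>/models/ldm/custom-model.safetensors"); nothing is downloaded

    fetched = manager._resolve_path("https://example.com/some-model.ckpt", "models/ldm")
    # -> downloads into <Globals.root>/models/ldm via download_with_resume() and returns the local Path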
@@ -54,10 +54,6 @@ def main():
             "--weights argument has been deprecated. Please edit ./configs/models.yaml, and select the weights using --model instead."
         )
         sys.exit(-1)
-    if args.max_loaded_models is not None:
-        if args.max_loaded_models <= 0:
-            print("--max_loaded_models must be >= 1; using 1")
-            args.max_loaded_models = 1

     # alert - setting a few globals here
     Globals.try_patchmatch = args.patchmatch
@@ -136,7 +132,7 @@ def main():
             esrgan=esrgan,
             free_gpu_mem=opt.free_gpu_mem,
             safety_checker=opt.safety_checker,
-            max_loaded_models=opt.max_loaded_models,
+            max_cache_size=opt.max_cache_size,
         )
     except (FileNotFoundError, TypeError, AssertionError) as e:
         report_model_error(opt, e)