mirror of https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00

ckpt model conversion now done in ModelCache

This commit is contained in:
parent a108155544
commit 9cb962cad7
@@ -150,7 +150,7 @@ class Generate:
         esrgan=None,
         free_gpu_mem: bool = False,
         safety_checker: bool = False,
-        max_loaded_models: int = 2,
+        max_cache_size: int = 6,
         # these are deprecated; if present they override values in the conf file
         weights=None,
         config=None,
@@ -183,7 +183,7 @@ class Generate:
         self.codeformer = codeformer
         self.esrgan = esrgan
         self.free_gpu_mem = free_gpu_mem
-        self.max_loaded_models = (max_loaded_models,)
+        self.max_cache_size = max_cache_size
         self.size_matters = True  # used to warn once about large image sizes and VRAM
         self.txt2mask = None
         self.safety_checker = None
@@ -220,7 +220,7 @@ class Generate:
             conf,
             self.device,
             torch_dtype(self.device),
-            max_loaded_models=max_loaded_models,
+            max_cache_size=max_cache_size,
             sequential_offload=self.free_gpu_mem,
             # embedding_path=Path(self.embedding_path),
         )
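The three hunks above are a single rename: Generate stops counting resident models (max_loaded_models, default 2) and instead passes the model cache a size budget in gigabytes (max_cache_size, default 6). The rename incidentally retires a bug visible in the second hunk, where self.max_loaded_models was assigned as a one-element tuple. A minimal sketch of the policy difference, with illustrative names rather than InvokeAI's actual cache code:

    from collections import OrderedDict

    GIG = 2 ** 30  # bytes per gigabyte; the diff's cache reports size as current_cache_size / GIG

    class SizeBudgetCache:
        """Evict least-recently-used models once resident bytes exceed a GB budget."""
        def __init__(self, max_cache_size: float = 6.0):
            self.max_bytes = max_cache_size * GIG
            self._models = OrderedDict()  # key -> (model, size_in_bytes)

        def put(self, key, model, size_bytes: int):
            self._models[key] = (model, size_bytes)
            self._models.move_to_end(key)                # mark most recently used
            while self._total() > self.max_bytes and len(self._models) > 1:
                self._models.popitem(last=False)         # drop the LRU entry

        def _total(self) -> int:
            return sum(size for _, size in self._models.values())

Under the old policy two models fit regardless of size; under the new one a single 5 GB model plus smaller submodels can coexist, while three 2.5 GB models cannot.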
@@ -94,6 +94,8 @@ def global_set_root(root_dir: Union[str, Path]):
     Globals.root = root_dir
 
 def global_resolve_path(path: Union[str,Path]):
+    if path is None:
+        return None
     return Path(Globals.root,path).resolve()
 
 def global_cache_dir(subdir: Union[str, Path] = "") -> Path:
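The two added lines make global_resolve_path() safe for optional config values. A behavior sketch, assuming Globals.root = "/home/user/invokeai" (the path is illustrative):

    # global_resolve_path("configs/models.yaml")
    #   -> PosixPath('/home/user/invokeai/configs/models.yaml')
    # global_resolve_path(None)
    #   -> None   (previously Path(Globals.root, None) raised TypeError)

This matters further down, where callers chain expressions such as global_resolve_path(mconfig.get('path')) or mconfig.get('repo_id'), which rely on None propagating through.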
@@ -361,9 +361,10 @@ class ModelCache(object):
                   )->ModelStatus:
         key = self._model_key(
             repo_id_or_path,
-            model_type.value,
             revision,
-            subfolder)
+            subfolder,
+            model_type.value,
+        )
         if key not in self.models:
             return ModelStatus.not_loaded
         if key in self.loaded_models:
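The cache key is built positionally from its arguments, so moving model_type.value from second place to last changes the key layout for every cached model; the _model_key() definition itself is outside this excerpt and must have been reordered to match. A hypothetical sketch of a key builder with the new argument order (not the committed implementation):

    def _model_key(repo_id_or_path, revision, subfolder, model_type) -> str:
        # Join the non-empty components in their new positional order.
        return ":".join(str(part) for part in (repo_id_or_path, revision, subfolder, model_type)
                        if part is not None)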
@@ -384,9 +385,7 @@ class ModelCache(object):
         :param revision: optional revision string (if fetching a HF repo_id)
         '''
         revision = revision or "main"
-        if self.is_legacy_ckpt(repo_id_or_path):
-            return self._legacy_model_hash(repo_id_or_path)
-        elif Path(repo_id_or_path).is_dir():
+        if Path(repo_id_or_path).is_dir():
             return self._local_model_hash(repo_id_or_path)
         else:
             return self._hf_commit_hash(repo_id_or_path,revision)
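With legacy checkpoints no longer hashed here (they are converted before the cache sees them, per the hunks below), the dispatch reduces to two cases: a local directory gets a content hash, and everything else is treated as a HuggingFace repo_id keyed by its commit hash. Neither helper is shown in this diff; the following is only an illustrative directory digest, not InvokeAI's _local_model_hash():

    import hashlib
    from pathlib import Path

    def local_model_hash(model_dir: Path) -> str:
        """Digest the weight files under a local diffusers directory, in stable order."""
        sha = hashlib.sha256()
        for file in sorted(model_dir.rglob("*")):
            if file.suffix in (".bin", ".safetensors"):
                sha.update(file.read_bytes())
        return sha.hexdigest()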
@@ -395,15 +394,6 @@ class ModelCache(object):
         "Return the current size of the cache, in GB"
         return self.current_cache_size / GIG
 
-    @classmethod
-    def is_legacy_ckpt(cls, repo_id_or_path: Union[str,Path])->bool:
-        '''
-        Return true if the indicated path is a legacy checkpoint
-        :param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
-        '''
-        path = Path(repo_id_or_path)
-        return path.suffix in [".ckpt",".safetensors",".pt"]
-
     @classmethod
     def scan_model(cls, model_name, checkpoint):
         """
@@ -482,10 +472,6 @@ class ModelCache(object):
         '''
         # silence transformer and diffuser warnings
         with SilenceWarnings():
-            # !!! NOTE: conversion should not happen here, but in ModelManager
-            if self.is_legacy_ckpt(repo_id_or_path):
-                model = self._load_ckpt_from_storage(repo_id_or_path, legacy_info)
-            else:
-                model = self._load_diffusers_from_storage(
+            model = self._load_diffusers_from_storage(
                 repo_id_or_path,
                 subfolder,
@@ -143,7 +143,7 @@ from omegaconf import OmegaConf
 from omegaconf.dictconfig import DictConfig
 
 from invokeai.backend.globals import Globals, global_cache_dir, global_resolve_path
-from .model_cache import ModelCache, ModelLocker, SDModelType, ModelStatus, LegacyInfo
+from .model_cache import ModelCache, ModelLocker, SDModelType, ModelStatus, SilenceWarnings
 
 from ..util import CUDA_DEVICE
 
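The import swap mirrors the moved responsibilities: ModelManager no longer builds LegacyInfo records for the cache, but it now needs the cache module's SilenceWarnings for the conversion paths added below (the diff's own comment says it silences transformer and diffuser warnings). Its definition is not part of this excerpt; a plausible shape for such a context manager, offered as an assumption:

    import warnings

    class SilenceWarnings:
        """Suppress Python warnings for the duration of a with-block."""
        def __enter__(self):
            self._catcher = warnings.catch_warnings()
            self._catcher.__enter__()
            warnings.simplefilter("ignore")

        def __exit__(self, *exc_info):
            return self._catcher.__exit__(*exc_info)

The real class may additionally lower the transformers/diffusers logging verbosity.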
@@ -225,12 +225,16 @@ class ModelManager(object):
         self.cache_keys = dict()
         self.logger = logger
 
-    def valid_model(self, model_name: str) -> bool:
+    def valid_model(self, model_name: str, model_type: SDModelType=SDModelType.diffusers) -> bool:
         """
         Given a model name, returns True if it is a valid
         identifier.
         """
-        return model_name in self.config
+        try:
+            self._disambiguate_name(model_name, model_type)
+            return True
+        except InvalidModelError:
+            return False
 
     def get_model(self,
                   model_name: str,
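Because valid_model() now defers to _disambiguate_name() (patched near the end of this diff), a bare name, a type-qualified name, and a non-diffusers model type all validate through the same path instead of a plain membership test. A usage sketch; the manager instance and model names are illustrative:

    mgr.valid_model("stable-diffusion-1.5")
    # True if "stable-diffusion-1.5" or "stable-diffusion-1.5/diffusers" is in the config

    mgr.valid_model("my-vae", model_type=SDModelType.vae)
    # checks "my-vae" and "my-vae/vae" instead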
@@ -294,17 +298,17 @@ class ModelManager(object):
         model_parts = dict([(x.name,x) for x in SDModelType])
         legacy = None
 
-        if format=='ckpt':
-            location = global_resolve_path(mconfig.weights)
-            legacy = LegacyInfo(
-                config_file = global_resolve_path(mconfig.config),
-            )
-            if mconfig.get('vae'):
-                legacy.vae_file = global_resolve_path(mconfig.vae)
-        elif format=='diffusers':
-            location = mconfig.get('repo_id') or mconfig.get('path')
+        if format == 'diffusers':
+            # intercept stanzas that point to checkpoint weights and replace them
+            # with the equivalent diffusers model
+            if 'weights' in mconfig:
+                location = self.convert_ckpt_and_cache(mconfig)
+            else:
+                location = global_resolve_path(mconfig.get('path')) or mconfig.get('repo_id')
         elif format in model_parts:
-            location = mconfig.get('repo_id') or mconfig.get('path') or mconfig.get('weights')
+            location = global_resolve_path(mconfig.get('path')) \
+                       or mconfig.get('repo_id') \
+                       or global_resolve_path(mconfig.get('weights'))
         else:
             raise InvalidModelError(
                 f'"{model_key}" has an unknown format {format}'
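The practical effect: a 'diffusers' stanza that still points at legacy .ckpt weights is converted and disk-cached on first access by convert_ckpt_and_cache() (added further down), while a stanza carrying a path or repo_id is used directly. Two illustrative stanzas, written with OmegaConf to match the DictConfig access pattern above (field values are made up):

    from omegaconf import OmegaConf

    legacy_stanza = OmegaConf.create({
        "format": "diffusers",
        "weights": "models/ldm/sd-v1-5.ckpt",                    # 'weights' present -> convert_ckpt_and_cache()
        "config": "configs/stable-diffusion/v1-inference.yaml",
    })

    native_stanza = OmegaConf.create({
        "format": "diffusers",
        "repo_id": "runwayml/stable-diffusion-v1-5",             # no 'weights' -> resolved as path/repo_id
    })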
@@ -531,7 +535,7 @@ class ModelManager(object):
         else:
             assert "weights" in model_attributes and "description" in model_attributes
 
-        model_key = f'{model_name}/{format}'
+        model_key = f'{model_name}/{model_attributes["format"]}'
 
         assert (
             clobber or model_key not in omega
@@ -776,7 +780,7 @@ class ModelManager(object):
         # another round of heuristics to guess the correct config file.
         checkpoint = None
         if model_path.suffix in [".ckpt", ".pt"]:
-            self.scan_model(model_path, model_path)
+            self.cache.scan_model(model_path, model_path)
             checkpoint = torch.load(model_path)
         else:
             checkpoint = safetensors.torch.load_file(model_path)
@@ -840,6 +844,7 @@ class ModelManager(object):
         diffuser_path = Path(
             Globals.root, "models", Globals.converted_ckpts_dir, model_path.stem
         )
+        with SilenceWarnings():
             model_name = self.convert_and_import(
                 model_path,
                 diffusers_path=diffuser_path,
@@ -853,6 +858,72 @@ class ModelManager(object):
             )
         return model_name
 
+    def convert_ckpt_and_cache(self, mconfig:DictConfig)->Path:
+        """
+        Convert the checkpoint model indicated in mconfig into a
+        diffusers, cache it to disk, and return Path to converted
+        file. If already on disk then just returns Path.
+        """
+        weights = global_resolve_path(mconfig.weights)
+        config_file = global_resolve_path(mconfig.config)
+        diffusers_path = global_resolve_path(Path('models',Globals.converted_ckpts_dir)) / weights.stem
+
+        # return cached version if it exists
+        if diffusers_path.exists():
+            return diffusers_path
+
+        vae_ckpt_path, vae_model = self._get_vae_for_conversion(weights, mconfig)
+        # to avoid circular import errors
+        from .convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
+        with SilenceWarnings():
+            convert_ckpt_to_diffusers(
+                weights,
+                diffusers_path,
+                extract_ema=True,
+                original_config_file=config_file,
+                vae=vae_model,
+                vae_path=str(global_resolve_path(vae_ckpt_path)),
+                scan_needed=True,
+            )
+        return diffusers_path
+
+    def _get_vae_for_conversion(self,
+                                weights: Path,
+                                mconfig: DictConfig
+                                )->tuple(Path,SDModelType.vae):
+        # VAE handling is convoluted
+        # 1. If there is a .vae.ckpt file sharing same stem as weights, then use
+        #    it as the vae_path passed to convert
+        vae_ckpt_path = None
+        vae_diffusers_location = None
+        vae_model = None
+        for suffix in ["pt", "ckpt", "safetensors"]:
+            if (weights.with_suffix(f".vae.{suffix}")).exists():
+                vae_ckpt_path = weights.with_suffix(f".vae.{suffix}")
+                self.logger.debug(f"Using VAE file {vae_ckpt_path.name}")
+        if vae_ckpt_path:
+            return (vae_ckpt_path, None)
+
+        # 2. If mconfig has a vae weights path, then we use that as vae_path
+        vae_config = mconfig.get('vae')
+        if vae_config and isinstance(vae_config,str):
+            vae_ckpt_path = vae_config
+            return (vae_ckpt_path, None)
+
+        # 3. If mconfig has a vae dict, then we use it as the diffusers-style vae
+        if vae_config and isinstance(vae_config,DictConfig):
+            vae_diffusers_location = global_resolve_path(vae_config.get('path')) or vae_config.get('repo_id')
+
+        # 4. Otherwise, we use stabilityai/sd-vae-ft-mse "because it works"
+        else:
+            vae_diffusers_location = "stabilityai/sd-vae-ft-mse"
+
+        if vae_diffusers_location:
+            vae_model = self.cache.get_model(vae_diffusers_location, SDModelType.vae).model
+            return (None, vae_model)
+
+        return (None, None)
+
     def convert_and_import(
         self,
         ckpt_path: Path,
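_get_vae_for_conversion() tries four sources in order: a sidecar VAE file next to the weights (Path.with_suffix() replaces only the final suffix, so models/sd-v1-5.ckpt is matched by models/sd-v1-5.vae.pt, .vae.ckpt, or .vae.safetensors); a vae: string in the stanza, treated as a checkpoint path; a vae: dict naming a diffusers VAE to load through the cache; and finally the stabilityai/sd-vae-ft-mse default. Two cautions for readers of this hunk. First, when no sidecar exists vae_ckpt_path stays None, and str(global_resolve_path(vae_ckpt_path)) in convert_ckpt_and_cache() then produces the literal string 'None', which the converter has to tolerate. Second, the return annotation tuple(Path,SDModelType.vae) calls the tuple constructor with two arguments and raises TypeError as soon as the method definition is evaluated, unless this module defers annotation evaluation with from __future__ import annotations. The conventional spelling, sketched here rather than taken from the commit:

    from pathlib import Path
    from typing import Optional, Tuple

    from omegaconf.dictconfig import DictConfig

    # Signature sketch only; the VAE object's concrete type is not visible in
    # this excerpt, so it is left as 'object'.
    def _get_vae_for_conversion(
        self, weights: Path, mconfig: DictConfig
    ) -> Tuple[Optional[Path], Optional[object]]:
        ...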
@@ -895,7 +966,8 @@ class ModelManager(object):
         # will be built into the model rather than tacked on afterward via the config file
         vae_model = None
         if vae:
-            vae_model = self._load_vae(vae)
+            vae_location = global_resolve_path(vae.get('path')) or vae.get('repo_id')
+            vae_model = self.cache.get_model(vae_location,SDModelType.vae).model
         vae_path = None
         convert_ckpt_to_diffusers(
             ckpt_path,
@@ -982,9 +1054,9 @@ class ModelManager(object):
     def _disambiguate_name(self, model_name:str, model_type:SDModelType)->str:
         model_type = model_type or SDModelType.diffusers
         full_name = f"{model_name}/{model_type.name}"
-        if self.valid_model(full_name):
+        if full_name in self.config:
             return full_name
-        if self.valid_model(model_name):
+        if model_name in self.config:
             return model_name
         raise InvalidModelError(
             f'Neither "{model_name}" nor "{full_name}" are known model names. Please check your models.yaml file'
@@ -1014,3 +1086,20 @@ class ModelManager(object):
             return path
         return Path(Globals.root, path).resolve()
 
+    # This is not the same as global_resolve_path(), which prepends
+    # Globals.root.
+    def _resolve_path(
+        self, source: Union[str, Path], dest_directory: str
+    ) -> Optional[Path]:
+        resolved_path = None
+        if str(source).startswith(("http:", "https:", "ftp:")):
+            dest_directory = Path(dest_directory)
+            if not dest_directory.is_absolute():
+                dest_directory = Globals.root / dest_directory
+            dest_directory.mkdir(parents=True, exist_ok=True)
+            resolved_path = download_with_resume(str(source), dest_directory)
+        else:
+            if not os.path.isabs(source):
+                source = os.path.join(Globals.root, source)
+            resolved_path = Path(source)
+        return resolved_path
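_resolve_path() branches on the source scheme: URL sources are downloaded, resumably, via download_with_resume() into a destination directory anchored at Globals.root, while local relative sources are merely anchored there and returned as a Path. A usage sketch; the manager instance, URL, and filenames are illustrative:

    # URL: fetched into <Globals.root>/models/ldm, returns the downloaded Path
    mgr._resolve_path("https://example.com/sd-v1-5.ckpt", "models/ldm")

    # Local relative path: joined onto Globals.root, nothing is downloaded
    mgr._resolve_path("models/ldm/sd-v1-5.ckpt", "models/ldm")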
@@ -54,10 +54,6 @@ def main():
             "--weights argument has been deprecated. Please edit ./configs/models.yaml, and select the weights using --model instead."
         )
         sys.exit(-1)
-    if args.max_loaded_models is not None:
-        if args.max_loaded_models <= 0:
-            print("--max_loaded_models must be >= 1; using 1")
-            args.max_loaded_models = 1
 
     # alert - setting a few globals here
     Globals.try_patchmatch = args.patchmatch
@@ -136,7 +132,7 @@ def main():
             esrgan=esrgan,
             free_gpu_mem=opt.free_gpu_mem,
             safety_checker=opt.safety_checker,
-            max_loaded_models=opt.max_loaded_models,
+            max_cache_size=opt.max_cache_size,
         )
     except (FileNotFoundError, TypeError, AssertionError) as e:
         report_model_error(opt, e)
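At the CLI boundary the rename lands twice: the pre-flight clamp on --max_loaded_models disappears along with the option's meaning, and Generate is constructed from opt.max_cache_size instead. Assuming the argument parser, which is not part of this excerpt, exposes the attribute under a flag of the same name, an invocation would look like:

    invokeai --max_cache_size 12    # budget roughly 12 GB of resident models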
|