mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
convert custom VAEs into diffusers
- When a legacy checkpoint model is loaded via --convert_ckpt and its models.yaml stanza refers to a custom VAE path (using the 'vae:' key), the custom VAE will be converted and used within the diffusers model. Otherwise the VAE contained within the legacy model will be used. - Note that the heuristic_import() method, which imports arbitrary legacy files on disk and URLs, will continue to default to the the standard stabilityai/sd-vae-ft-mse VAE. This can be fixed after the fact by editing the models.yaml stanza using the Web or CLI UIs. - Fixes issue #2917
This commit is contained in:
parent
a958ae5e29
commit
4e0b5d85ba
@ -1033,7 +1033,7 @@ def replace_checkpoint_vae(checkpoint, vae_path:str):
|
|||||||
vae_ckpt = torch.load(vae_path, map_location="cpu")
|
vae_ckpt = torch.load(vae_path, map_location="cpu")
|
||||||
for vae_key in vae_ckpt['state_dict']:
|
for vae_key in vae_ckpt['state_dict']:
|
||||||
new_key = f'first_stage_model.{vae_key}'
|
new_key = f'first_stage_model.{vae_key}'
|
||||||
checkpoint[new_key] = vae_ckpt['state_dict'][vae_key]
|
checkpoint[new_key] = vae_ckpt['state_dict'][vae_key]
|
||||||
|
|
||||||
def load_pipeline_from_original_stable_diffusion_ckpt(
|
def load_pipeline_from_original_stable_diffusion_ckpt(
|
||||||
checkpoint_path: str,
|
checkpoint_path: str,
|
||||||
|
@ -454,7 +454,12 @@ class ModelManager(object):
|
|||||||
|
|
||||||
from . import load_pipeline_from_original_stable_diffusion_ckpt
|
from . import load_pipeline_from_original_stable_diffusion_ckpt
|
||||||
|
|
||||||
self.offload_model(self.current_model)
|
try:
|
||||||
|
if self.list_models()[self.current_model]['status'] == 'active':
|
||||||
|
self.offload_model(self.current_model)
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
|
||||||
vae_path = None
|
vae_path = None
|
||||||
if vae:
|
if vae:
|
||||||
vae_path = vae if os.path.isabs(vae) else os.path.normpath(os.path.join(Globals.root, vae))
|
vae_path = vae if os.path.isabs(vae) else os.path.normpath(os.path.join(Globals.root, vae))
|
||||||
@ -510,6 +515,7 @@ class ModelManager(object):
|
|||||||
print(f">> Offloading {model_name} to CPU")
|
print(f">> Offloading {model_name} to CPU")
|
||||||
model = self.models[model_name]["model"]
|
model = self.models[model_name]["model"]
|
||||||
model.offload_all()
|
model.offload_all()
|
||||||
|
self.current_model = None
|
||||||
|
|
||||||
gc.collect()
|
gc.collect()
|
||||||
if self._has_cuda():
|
if self._has_cuda():
|
||||||
@ -790,14 +796,15 @@ v Apply picklescanner to the indicated checkpoint and issue a warning
|
|||||||
return model_name
|
return model_name
|
||||||
|
|
||||||
def convert_and_import(
|
def convert_and_import(
|
||||||
self,
|
self,
|
||||||
ckpt_path: Path,
|
ckpt_path: Path,
|
||||||
diffusers_path: Path,
|
diffusers_path: Path,
|
||||||
model_name=None,
|
model_name=None,
|
||||||
model_description=None,
|
model_description=None,
|
||||||
vae=None,
|
vae:dict=None,
|
||||||
original_config_file: Path = None,
|
vae_path:Path=None,
|
||||||
commit_to_conf: Path = None,
|
original_config_file: Path = None,
|
||||||
|
commit_to_conf: Path = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Convert a legacy ckpt weights file to diffuser model and import
|
Convert a legacy ckpt weights file to diffuser model and import
|
||||||
@ -825,13 +832,17 @@ v Apply picklescanner to the indicated checkpoint and issue a warning
|
|||||||
try:
|
try:
|
||||||
# By passing the specified VAE to the conversion function, the autoencoder
|
# By passing the specified VAE to the conversion function, the autoencoder
|
||||||
# will be built into the model rather than tacked on afterward via the config file
|
# will be built into the model rather than tacked on afterward via the config file
|
||||||
vae_model = self._load_vae(vae) if vae else None
|
vae_model=None
|
||||||
|
if vae:
|
||||||
|
vae_model=self._load_vae(vae)
|
||||||
|
vae_path=None
|
||||||
convert_ckpt_to_diffusers(
|
convert_ckpt_to_diffusers(
|
||||||
ckpt_path,
|
ckpt_path,
|
||||||
diffusers_path,
|
diffusers_path,
|
||||||
extract_ema=True,
|
extract_ema=True,
|
||||||
original_config_file=original_config_file,
|
original_config_file=original_config_file,
|
||||||
vae=vae_model,
|
vae=vae_model,
|
||||||
|
vae_path=vae_path,
|
||||||
)
|
)
|
||||||
print(
|
print(
|
||||||
f" | Success. Optimized model is now located at {str(diffusers_path)}"
|
f" | Success. Optimized model is now located at {str(diffusers_path)}"
|
||||||
|
@ -772,16 +772,10 @@ def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
|
|||||||
original_config_file = Path(model_info["config"])
|
original_config_file = Path(model_info["config"])
|
||||||
model_name = model_name_or_path
|
model_name = model_name_or_path
|
||||||
model_description = model_info["description"]
|
model_description = model_info["description"]
|
||||||
vae = model_info["vae"]
|
vae_path = model_info.get("vae")
|
||||||
else:
|
else:
|
||||||
print(f"** {model_name_or_path} is not a legacy .ckpt weights file")
|
print(f"** {model_name_or_path} is not a legacy .ckpt weights file")
|
||||||
return
|
return
|
||||||
if vae_repo := invokeai.backend.model_management.model_manager.VAE_TO_REPO_ID.get(
|
|
||||||
Path(vae).stem
|
|
||||||
):
|
|
||||||
vae_repo = dict(repo_id=vae_repo)
|
|
||||||
else:
|
|
||||||
vae_repo = None
|
|
||||||
model_name = manager.convert_and_import(
|
model_name = manager.convert_and_import(
|
||||||
ckpt_path,
|
ckpt_path,
|
||||||
diffusers_path=Path(
|
diffusers_path=Path(
|
||||||
@ -790,7 +784,7 @@ def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
|
|||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
model_description=model_description,
|
model_description=model_description,
|
||||||
original_config_file=original_config_file,
|
original_config_file=original_config_file,
|
||||||
vae=vae_repo,
|
vae_path=vae_path,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user