Convert custom VAEs during legacy checkpoint loading (#3010)

- When a legacy checkpoint model is loaded via --convert_ckpt and its
models.yaml stanza refers to a custom VAE path (using the 'vae:' key),
the custom VAE will be converted and built into the diffusers model.
Otherwise, the VAE contained within the legacy checkpoint will be used.
(See the sample stanza and path-resolution sketch after this list.)
    
- Note that the checkpoint import functions in the CLI and Web UI
continue to default to the standard stabilityai/sd-vae-ft-mse VAE. This
can be fixed after the fact by editing the model's VAE key using either
the CLI or the Web UI.
   
- Fixes issue #2917
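
For illustration, here is a hypothetical models.yaml stanza for a legacy
checkpoint with a custom VAE, shown as the dictionary it parses to, together
with the path-resolution rule applied before conversion (absolute paths pass
through, relative paths are resolved against the InvokeAI root). The model
name, file paths, and helper variable names are illustrative only; the 'vae:'
key is the one this commit acts on, and the resolution mirrors the model
manager change below.

    import os

    # hypothetical legacy stanza from models.yaml, as parsed; only "vae" matters here
    stanza = {
        "description": "legacy checkpoint with a custom VAE",
        "weights": "models/ldm/stable-diffusion-v1/my-fine-tune.ckpt",
        "config": "configs/stable-diffusion/v1-inference.yaml",
        "vae": "models/ldm/stable-diffusion-v1/custom-vae.ckpt",
    }

    invokeai_root = "/home/user/invokeai"  # stand-in for Globals.root
    vae = stanza.get("vae")
    vae_path = None
    if vae:
        # same rule as the model manager: absolute paths are used as-is,
        # relative paths are resolved against the InvokeAI root
        vae_path = vae if os.path.isabs(vae) else os.path.normpath(os.path.join(invokeai_root, vae))
    print(vae_path)  # e.g. /home/user/invokeai/models/ldm/stable-diffusion-v1/custom-vae.ckpt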
Commit 9536ba22af by Lincoln Stein, 2023-03-25 00:37:12 -04:00 (committed via GitHub)
3 changed files with 53 additions and 60 deletions

Changed file 1 of 3: the checkpoint-to-diffusers conversion module

@@ -1036,6 +1036,15 @@ def convert_open_clip_checkpoint(checkpoint):
     return text_model
 
 
+def replace_checkpoint_vae(checkpoint, vae_path:str):
+    if vae_path.endswith(".safetensors"):
+        vae_ckpt = load_file(vae_path)
+    else:
+        vae_ckpt = torch.load(vae_path, map_location="cpu")
+    state_dict = vae_ckpt['state_dict'] if "state_dict" in vae_ckpt else vae_ckpt
+    for vae_key in state_dict:
+        new_key = f'first_stage_model.{vae_key}'
+        checkpoint[new_key] = state_dict[vae_key]
+
 def load_pipeline_from_original_stable_diffusion_ckpt(
     checkpoint_path: str,
@@ -1048,6 +1057,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
     extract_ema: bool = True,
     upcast_attn: bool = False,
     vae: AutoencoderKL = None,
+    vae_path: str = None,
     precision: torch.dtype = torch.float32,
     return_generator_pipeline: bool = False,
     scan_needed:bool=True,
@@ -1078,6 +1088,8 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
     :param precision: precision to use - torch.float16, torch.float32 or torch.autocast
     :param upcast_attention: Whether the attention computation should always be upcasted. This is necessary when
         running stable diffusion 2.1.
+    :param vae: A diffusers VAE to load into the pipeline.
+    :param vae_path: Path to a checkpoint VAE that will be converted into diffusers and loaded into the pipeline.
     """
 
     with warnings.catch_warnings():
@@ -1214,9 +1226,19 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
         unet.load_state_dict(converted_unet_checkpoint)
 
         # Convert the VAE model, or use the one passed
-        if not vae:
-            print(" | Using checkpoint model's original VAE")
+        # If a replacement VAE path was specified, we'll incorporate that into
+        # the checkpoint model and then convert it
+        if vae_path:
+            print(f" | Converting VAE {vae_path}")
+            replace_checkpoint_vae(checkpoint,vae_path)
+        # otherwise we use the original VAE, provided that
+        # an externally loaded diffusers VAE was not passed
+        elif not vae:
+            print(" | Using checkpoint model's original VAE")
+
+        if vae:
+            print(" | Using replacement diffusers VAE")
+        else: # convert the original or replacement VAE
             vae_config = create_vae_diffusers_config(
                 original_config, image_size=image_size
             )
@@ -1226,8 +1248,6 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
             vae = AutoencoderKL(**vae_config)
             vae.load_state_dict(converted_vae_checkpoint)
-        else:
-            print(" | Using external VAE specified in config")
 
         # Convert the text model.
         model_type = pipeline_type

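To make the effect of the new replace_checkpoint_vae() helper concrete, here is
a toy, self-contained illustration using dummy tensors rather than real model
weights: every key of the standalone VAE's state dict is copied into the legacy
checkpoint under the first_stage_model. prefix, which is where the existing
VAE-conversion path (create_vae_diffusers_config / AutoencoderKL above) already
looks for the autoencoder weights, so that path needs no further changes.

    import torch

    legacy_checkpoint = {
        # original VAE weight inside the legacy checkpoint
        "first_stage_model.encoder.conv_in.weight": torch.zeros(3),
        # unrelated UNet weight, left untouched
        "model.diffusion_model.input_blocks.0.0.weight": torch.zeros(3),
    }
    custom_vae_state_dict = {
        "encoder.conv_in.weight": torch.ones(3),  # replacement VAE weight
    }

    # the same re-keying loop the new helper performs
    for vae_key, tensor in custom_vae_state_dict.items():
        legacy_checkpoint[f"first_stage_model.{vae_key}"] = tensor

    assert torch.equal(
        legacy_checkpoint["first_stage_model.encoder.conv_in.weight"], torch.ones(3)
    )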
Changed file 2 of 3: the model manager (ModelManager)

@@ -45,9 +45,6 @@ class SDLegacyType(Enum):
     UNKNOWN = 99
 
 DEFAULT_MAX_MODELS = 2
-VAE_TO_REPO_ID = { # hack, see note in convert_and_import()
-    "vae-ft-mse-840000-ema-pruned": "stabilityai/sd-vae-ft-mse",
-}
 
 class ModelManager(object):
     '''
@@ -457,15 +454,21 @@ class ModelManager(object):
         from . import load_pipeline_from_original_stable_diffusion_ckpt
 
-        self.offload_model(self.current_model)
-        if vae_config := self._choose_diffusers_vae(model_name):
-            vae = self._load_vae(vae_config)
+        try:
+            if self.list_models()[self.current_model]['status'] == 'active':
+                self.offload_model(self.current_model)
+        except Exception as e:
+            pass
+
+        vae_path = None
+        if vae:
+            vae_path = vae if os.path.isabs(vae) else os.path.normpath(os.path.join(Globals.root, vae))
         if self._has_cuda():
             torch.cuda.empty_cache()
 
         pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
             checkpoint_path=weights,
             original_config_file=config,
-            vae=vae,
+            vae_path=vae_path,
             return_generator_pipeline=True,
             precision=torch.float16 if self.precision == "float16" else torch.float32,
         )
@@ -512,6 +515,7 @@ class ModelManager(object):
         print(f">> Offloading {model_name} to CPU")
         model = self.models[model_name]["model"]
         model.offload_all()
+        self.current_model = None
 
         gc.collect()
         if self._has_cuda():
@@ -795,15 +799,16 @@ class ModelManager(object):
         return model_name
 
     def convert_and_import(
-            self,
-            ckpt_path: Path,
-            diffusers_path: Path,
-            model_name=None,
-            model_description=None,
-            vae=None,
-            original_config_file: Path = None,
-            commit_to_conf: Path = None,
-            scan_needed: bool=True,
+        self,
+        ckpt_path: Path,
+        diffusers_path: Path,
+        model_name=None,
+        model_description=None,
+        vae:dict=None,
+        vae_path:Path=None,
+        original_config_file: Path = None,
+        commit_to_conf: Path = None,
+        scan_needed: bool=True,
     ) -> str:
         """
         Convert a legacy ckpt weights file to diffuser model and import
@@ -831,13 +836,17 @@ class ModelManager(object):
         try:
             # By passing the specified VAE to the conversion function, the autoencoder
             # will be built into the model rather than tacked on afterward via the config file
-            vae_model = self._load_vae(vae) if vae else None
+            vae_model=None
+            if vae:
+                vae_model=self._load_vae(vae)
+                vae_path=None
             convert_ckpt_to_diffusers(
                 ckpt_path,
                 diffusers_path,
                 extract_ema=True,
                 original_config_file=original_config_file,
                 vae=vae_model,
+                vae_path=vae_path,
                 scan_needed=scan_needed,
             )
             print(
@@ -884,36 +893,6 @@ class ModelManager(object):
         return search_folder, found_models
 
-    def _choose_diffusers_vae(
-        self, model_name: str, vae: str = None
-    ) -> Union[dict, str]:
-        # In the event that the original entry is using a custom ckpt VAE, we try to
-        # map that VAE onto a diffuser VAE using a hard-coded dictionary.
-        # I would prefer to do this differently: We load the ckpt model into memory, swap the
-        # VAE in memory, and then pass that to convert_ckpt_to_diffuser() so that the swapped
-        # VAE is built into the model. However, when I tried this I got obscure key errors.
-        if vae:
-            return vae
-        if model_name in self.config and (
-            vae_ckpt_path := self.model_info(model_name).get("vae", None)
-        ):
-            vae_basename = Path(vae_ckpt_path).stem
-            diffusers_vae = None
-            if diffusers_vae := VAE_TO_REPO_ID.get(vae_basename, None):
-                print(
-                    f">> {vae_basename} VAE corresponds to known {diffusers_vae} diffusers version"
-                )
-                vae = {"repo_id": diffusers_vae}
-            else:
-                print(
-                    f'** Custom VAE "{vae_basename}" found, but corresponding diffusers model unknown'
-                )
-                print(
-                    '** Using "stabilityai/sd-vae-ft-mse"; If this isn\'t right, please edit the model config'
-                )
-                vae = {"repo_id": "stabilityai/sd-vae-ft-mse"}
-        return vae
-
     def _make_cache_room(self) -> None:
         num_loaded_models = len(self.models)
         if num_loaded_models >= self.max_loaded_models:

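As a minimal sketch of the precedence convert_and_import() now enforces between
its two VAE arguments (the function and variable names below are illustrative,
not part of the patch): an explicitly supplied diffusers VAE spec wins, and the
checkpoint-format vae_path is cleared so that only one replacement VAE reaches
the converter.

    def resolve_vae_arguments(vae=None, vae_path=None):
        """Stand-in for the argument handling inside convert_and_import()."""
        vae_model = None
        if vae:
            vae_model = {"loaded_from": vae}  # stand-in for self._load_vae(vae)
            vae_path = None  # the diffusers VAE takes priority over the ckpt VAE
        return vae_model, vae_path

    # no VAE override at all -> convert whatever the checkpoint contains
    assert resolve_vae_arguments() == (None, None)
    # only a checkpoint VAE path -> it is spliced in during conversion
    assert resolve_vae_arguments(vae_path="custom-vae.ckpt") == (None, "custom-vae.ckpt")
    # both given -> the diffusers VAE wins and the path is dropped
    assert resolve_vae_arguments(
        vae={"repo_id": "stabilityai/sd-vae-ft-mse"}, vae_path="custom-vae.ckpt"
    ) == ({"loaded_from": {"repo_id": "stabilityai/sd-vae-ft-mse"}}, None)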
Changed file 3 of 3: the CLI's convert_model() command

@@ -772,16 +772,10 @@ def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
         original_config_file = Path(model_info["config"])
         model_name = model_name_or_path
         model_description = model_info["description"]
-        vae = model_info["vae"]
+        vae_path = model_info.get("vae")
     else:
         print(f"** {model_name_or_path} is not a legacy .ckpt weights file")
         return
-    if vae_repo := invokeai.backend.model_management.model_manager.VAE_TO_REPO_ID.get(
-        Path(vae).stem
-    ):
-        vae_repo = dict(repo_id=vae_repo)
-    else:
-        vae_repo = None
     model_name = manager.convert_and_import(
         ckpt_path,
         diffusers_path=Path(
@@ -790,7 +784,7 @@ def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
         model_name=model_name,
         model_description=model_description,
         original_config_file=original_config_file,
-        vae=vae_repo,
+        vae_path=vae_path,
     )
     else:
         try: