Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
speculative fix for alternative vaes

commit 3ca654d256
parent d9dab1b6c7
@@ -1026,6 +1026,14 @@ def convert_open_clip_checkpoint(checkpoint):
 
     return text_model
 
+def replace_checkpoint_vae(checkpoint, vae_path:str):
+    if vae_path.endswith(".safetensors"):
+        vae_ckpt = load_file(vae_path)
+    else:
+        vae_ckpt = torch.load(vae_path, map_location="cpu")
+    for vae_key in vae_ckpt['state_dict']:
+        new_key = f'first_stage_model.{vae_key}'
+        checkpoint[new_key] = vae_ckpt['state_dict'][vae_key]
+
 def load_pipeline_from_original_stable_diffusion_ckpt(
     checkpoint_path: str,
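The new helper grafts a standalone VAE's weights into the full checkpoint by re-keying them under the first_stage_model. prefix used by the original LDM checkpoint layout, so the normal VAE conversion further down picks them up unchanged. (One caveat worth flagging: safetensors' load_file returns a flat tensor dict with no 'state_dict' key, so the .safetensors branch appears to reach the ['state_dict'] lookup with the wrong shape of data; the torch.load branch matches it.) A minimal standalone sketch of the grafting idea, with hypothetical file paths and a torch-format VAE:

    import torch

    # Load the full SD checkpoint and a standalone torch-format VAE
    # (both paths are hypothetical).
    sd = torch.load("sd-v1-5.ckpt", map_location="cpu")
    checkpoint = sd["state_dict"]
    vae_ckpt = torch.load("custom-vae.ckpt", map_location="cpu")

    # Re-key the VAE weights under the prefix the LDM layout expects;
    # downstream conversion then treats them as the checkpoint's own VAE.
    for vae_key, tensor in vae_ckpt["state_dict"].items():
        checkpoint[f"first_stage_model.{vae_key}"] = tensor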
@@ -1038,6 +1046,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
     extract_ema: bool = True,
     upcast_attn: bool = False,
     vae: AutoencoderKL = None,
+    vae_path: str = None,
     precision: torch.dtype = torch.float32,
     return_generator_pipeline: bool = False,
 ) -> Union[StableDiffusionPipeline, StableDiffusionGeneratorPipeline]:
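With the new parameter in place, a caller can hand the converter a checkpoint-format VAE directly instead of a pre-built diffusers one. A minimal sketch of the call surface (all paths hypothetical; the import is assumed from the `from . import ...` statement shown later in this commit):

    import torch
    from ldm.invoke import load_pipeline_from_original_stable_diffusion_ckpt  # assumed import path

    pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
        checkpoint_path="models/sd-v1-5.ckpt",                # hypothetical
        original_config_file="configs/v1-inference.yaml",     # hypothetical
        vae_path="models/vae-ft-mse-840000-ema-pruned.safetensors",
        return_generator_pipeline=True,
        precision=torch.float16,
    )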
@@ -1067,6 +1076,8 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
     :param precision: precision to use - torch.float16, torch.float32 or torch.autocast
     :param upcast_attention: Whether the attention computation should always be upcasted. This is necessary when
         running stable diffusion 2.1.
+    :param vae: A diffusers VAE to load into the pipeline.
+    :param vae_path: Path to a checkpoint VAE that will be converted into diffusers and loaded into the pipeline.
     """
 
     with warnings.catch_warnings():
@@ -1201,9 +1212,19 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
 
     unet.load_state_dict(converted_unet_checkpoint)
 
-    # Convert the VAE model, or use the one passed
-    if not vae:
+    # If a replacement VAE path was specified, we'll incorporate that into
+    # the checkpoint model and then convert it
+    if vae_path:
+        print(f" | Converting VAE {vae_path}")
+        replace_checkpoint_vae(checkpoint,vae_path)
+    # otherwise we use the original VAE, provided that
+    # an externally loaded diffusers VAE was not passed
+    elif not vae:
         print(" | Using checkpoint model's original VAE")
 
+    if vae:
+        print(" | Using replacement diffusers VAE")
+    else: # convert the original or replacement VAE
         vae_config = create_vae_diffusers_config(
             original_config, image_size=image_size
         )
@@ -1213,8 +1234,6 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
 
         vae = AutoencoderKL(**vae_config)
         vae.load_state_dict(converted_vae_checkpoint)
-    else:
-        print(" | Using external VAE specified in config")
 
     # Convert the text model.
     model_type = pipeline_type
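Taken together, the branches establish a clear precedence: an in-memory diffusers `vae` argument skips conversion entirely (and wins even if a `vae_path` was also given); otherwise the original or grafted VAE weights are converted as before. The grafting works because the conversion branch reads VAE weights from the checkpoint's first_stage_model.* keys, which `replace_checkpoint_vae` has just overwritten. A condensed sketch of that branch's effect (the `convert_ldm_vae_checkpoint` name is an assumption for whichever helper in this file produces `converted_vae_checkpoint`):

    # The grafted first_stage_model.* keys in `checkpoint` are exactly what
    # the VAE conversion helper extracts, so a grafted VAE is converted the
    # same way the checkpoint's original VAE would be.
    vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
    converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)  # assumed helper
    vae = AutoencoderKL(**vae_config)
    vae.load_state_dict(converted_vae_checkpoint)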
@@ -45,9 +45,6 @@ class SDLegacyType(Enum):
     UNKNOWN = 99
 
 DEFAULT_MAX_MODELS = 2
-VAE_TO_REPO_ID = { # hack, see note in convert_and_import()
-    "vae-ft-mse-840000-ema-pruned": "stabilityai/sd-vae-ft-mse",
-}
 
 class ModelManager(object):
     '''
@@ -458,14 +455,15 @@ class ModelManager(object):
         from . import load_pipeline_from_original_stable_diffusion_ckpt
 
         self.offload_model(self.current_model)
-        if vae_config := self._choose_diffusers_vae(model_name):
-            vae = self._load_vae(vae_config)
+        vae_path = None
+        if vae:
+            vae_path = vae if os.path.isabs(vae) else os.path.normpath(os.path.join(Globals.root, vae))
         if self._has_cuda():
             torch.cuda.empty_cache()
         pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
             checkpoint_path=weights,
             original_config_file=config,
-            vae=vae,
+            vae_path=vae_path,
             return_generator_pipeline=True,
             precision=torch.float16 if self.precision == "float16" else torch.float32,
         )
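On the ModelManager side, the old repo-ID lookup gives way to resolving the configured VAE entry to a filesystem path, anchored at the InvokeAI runtime root when it is relative. A small sketch of the same resolution (the `vae` value and the Globals import path are assumptions):

    import os
    from ldm.invoke.globals import Globals  # assumed import path

    vae = "models/vae-ft-mse-840000-ema-pruned.safetensors"  # e.g. from the model config
    vae_path = (
        vae if os.path.isabs(vae)
        else os.path.normpath(os.path.join(Globals.root, vae))
    )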
@@ -519,7 +517,7 @@ class ModelManager(object):
 
     def scan_model(self, model_name, checkpoint):
         """
-        Apply picklescanner to the indicated checkpoint and issue a warning
+        v Apply picklescanner to the indicated checkpoint and issue a warning
         and option to exit if an infected file is identified.
         """
         # scan model
@@ -879,36 +877,6 @@ class ModelManager(object):
 
         return search_folder, found_models
 
-    def _choose_diffusers_vae(
-        self, model_name: str, vae: str = None
-    ) -> Union[dict, str]:
-        # In the event that the original entry is using a custom ckpt VAE, we try to
-        # map that VAE onto a diffuser VAE using a hard-coded dictionary.
-        # I would prefer to do this differently: We load the ckpt model into memory, swap the
-        # VAE in memory, and then pass that to convert_ckpt_to_diffuser() so that the swapped
-        # VAE is built into the model. However, when I tried this I got obscure key errors.
-        if vae:
-            return vae
-        if model_name in self.config and (
-            vae_ckpt_path := self.model_info(model_name).get("vae", None)
-        ):
-            vae_basename = Path(vae_ckpt_path).stem
-            diffusers_vae = None
-            if diffusers_vae := VAE_TO_REPO_ID.get(vae_basename, None):
-                print(
-                    f">> {vae_basename} VAE corresponds to known {diffusers_vae} diffusers version"
-                )
-                vae = {"repo_id": diffusers_vae}
-            else:
-                print(
-                    f'** Custom VAE "{vae_basename}" found, but corresponding diffusers model unknown'
-                )
-                print(
-                    '** Using "stabilityai/sd-vae-ft-mse"; If this isn\'t right, please edit the model config'
-                )
-                vae = {"repo_id": "stabilityai/sd-vae-ft-mse"}
-        return vae
-
     def _make_cache_room(self) -> None:
         num_loaded_models = len(self.models)
         if num_loaded_models >= self.max_loaded_models:
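The comment block inside the deleted `_choose_diffusers_vae` is worth reading against the rest of the commit: it describes loading the checkpoint, swapping the VAE in memory, and converting the result as the author's preferred design. That is precisely what `replace_checkpoint_vae` and the new `vae_path` plumbing implement, which is why this method and the hard-coded `VAE_TO_REPO_ID` mapping can be dropped.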