make convert script respect setting of use_ema in config file

This commit is contained in:
Lincoln Stein 2023-07-29 17:17:45 -04:00
parent 2a2d988928
commit 3f9105be50

View File

@ -422,11 +422,8 @@ def convert_ldm_unet_checkpoint(
)
for key in keys:
if key.startswith("model.diffusion_model"):
for delimiter in ['','.']:
flat_ema_key = "model_ema." + delimiter.join(key.split(".")[1:])
if checkpoint.get(flat_ema_key) is not None:
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
break
flat_ema_key = "model_ema." + "".join(key.split(".")[2:])
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
else:
if sum(k.startswith("model_ema") for k in keys) > 100:
logger.warning(
@ -1114,7 +1111,6 @@ def convert_controlnet_checkpoint(
return controlnet.to(precision)
# TO DO - PASS PRECISION
def download_from_original_stable_diffusion_ckpt(
checkpoint_path: str,
model_version: BaseModelType,
@ -1288,6 +1284,9 @@ def download_from_original_stable_diffusion_ckpt(
original_config_file = BytesIO(requests.get(config_url).content)
original_config = OmegaConf.load(original_config_file)
if original_config['model']['params'].get('use_ema') is not None:
extract_ema = original_config['model']['params']['use_ema']
if (
model_version == BaseModelType.StableDiffusion2
and original_config["model"]["params"].get("parameterization") == "v"