mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
make convert script respect setting of use_ema in config file
This commit is contained in:
parent
2a2d988928
commit
3f9105be50
@ -422,11 +422,8 @@ def convert_ldm_unet_checkpoint(
|
||||
)
|
||||
for key in keys:
|
||||
if key.startswith("model.diffusion_model"):
|
||||
for delimiter in ['','.']:
|
||||
flat_ema_key = "model_ema." + delimiter.join(key.split(".")[1:])
|
||||
if checkpoint.get(flat_ema_key) is not None:
|
||||
flat_ema_key = "model_ema." + "".join(key.split(".")[2:])
|
||||
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
|
||||
break
|
||||
else:
|
||||
if sum(k.startswith("model_ema") for k in keys) > 100:
|
||||
logger.warning(
|
||||
@ -1114,7 +1111,6 @@ def convert_controlnet_checkpoint(
|
||||
return controlnet.to(precision)
|
||||
|
||||
|
||||
# TO DO - PASS PRECISION
|
||||
def download_from_original_stable_diffusion_ckpt(
|
||||
checkpoint_path: str,
|
||||
model_version: BaseModelType,
|
||||
@ -1288,6 +1284,9 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
original_config_file = BytesIO(requests.get(config_url).content)
|
||||
|
||||
original_config = OmegaConf.load(original_config_file)
|
||||
if original_config['model']['params'].get('use_ema') is not None:
|
||||
extract_ema = original_config['model']['params']['use_ema']
|
||||
|
||||
if (
|
||||
model_version == BaseModelType.StableDiffusion2
|
||||
and original_config["model"]["params"].get("parameterization") == "v"
|
||||
|
Loading…
Reference in New Issue
Block a user