mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
make convert script respect setting of use_ema in config file
This commit is contained in:
parent
2a2d988928
commit
3f9105be50
@ -422,11 +422,8 @@ def convert_ldm_unet_checkpoint(
|
|||||||
)
|
)
|
||||||
for key in keys:
|
for key in keys:
|
||||||
if key.startswith("model.diffusion_model"):
|
if key.startswith("model.diffusion_model"):
|
||||||
for delimiter in ['','.']:
|
flat_ema_key = "model_ema." + "".join(key.split(".")[2:])
|
||||||
flat_ema_key = "model_ema." + delimiter.join(key.split(".")[1:])
|
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
|
||||||
if checkpoint.get(flat_ema_key) is not None:
|
|
||||||
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
|
|
||||||
break
|
|
||||||
else:
|
else:
|
||||||
if sum(k.startswith("model_ema") for k in keys) > 100:
|
if sum(k.startswith("model_ema") for k in keys) > 100:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@ -1114,7 +1111,6 @@ def convert_controlnet_checkpoint(
|
|||||||
return controlnet.to(precision)
|
return controlnet.to(precision)
|
||||||
|
|
||||||
|
|
||||||
# TO DO - PASS PRECISION
|
|
||||||
def download_from_original_stable_diffusion_ckpt(
|
def download_from_original_stable_diffusion_ckpt(
|
||||||
checkpoint_path: str,
|
checkpoint_path: str,
|
||||||
model_version: BaseModelType,
|
model_version: BaseModelType,
|
||||||
@ -1288,6 +1284,9 @@ def download_from_original_stable_diffusion_ckpt(
|
|||||||
original_config_file = BytesIO(requests.get(config_url).content)
|
original_config_file = BytesIO(requests.get(config_url).content)
|
||||||
|
|
||||||
original_config = OmegaConf.load(original_config_file)
|
original_config = OmegaConf.load(original_config_file)
|
||||||
|
if original_config['model']['params'].get('use_ema') is not None:
|
||||||
|
extract_ema = original_config['model']['params']['use_ema']
|
||||||
|
|
||||||
if (
|
if (
|
||||||
model_version == BaseModelType.StableDiffusion2
|
model_version == BaseModelType.StableDiffusion2
|
||||||
and original_config["model"]["params"].get("parameterization") == "v"
|
and original_config["model"]["params"].get("parameterization") == "v"
|
||||||
|
Loading…
Reference in New Issue
Block a user