# InvokeAI/ldm/invoke/merge_diffusers.py
#
# 63 lines
# 3.3 KiB
# Python

'''
ldm.invoke.merge_diffusers exports a single function, merge_diffusion_models(),
which merges 2-3 models together and creates a new InvokeAI-registered diffusion model.
'''
import os
from typing import List, Optional

from diffusers import DiffusionPipeline
from omegaconf import OmegaConf

from ldm.invoke.globals import global_config_file, global_models_dir, global_cache_dir
from ldm.invoke.model_manager import ModelManager
def merge_diffusion_models(models: List[str],
                           merged_model_name: str,
                           alpha: float = 0.5,
                           interp: Optional[str] = None,
                           force: bool = False,
                           **kwargs):
    '''
    Merge several diffusers models and register the result as a new InvokeAI model.

    The merged pipeline is written to <global_models_dir()>/merged_diffusers/<merged_model_name>,
    imported into the model manager under merged_model_name, and the updated
    configuration is committed back to the file returned by global_config_file().
    If the first model in `models` has a 'vae' entry in its configuration, that
    entry is copied to the merged model's configuration.

    models - up to three models, designated by their InvokeAI models.yaml model name
    merged_model_name - name for new model
    alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
            would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
    interp - The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None.
             Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported.
    force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
    **kwargs - the default DiffusionPipeline.get_config_dict kwargs:
               cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map

    Returns None.

    Raises AssertionError if any entry of `models` is not known to the model
    manager or is not stored in 'diffusers' format.
    '''
    config_file = global_config_file()
    model_manager = ModelManager(OmegaConf.load(config_file))

    # Validate with explicit raises rather than `assert` statements: asserts are
    # removed when Python runs with -O, which would silently skip these checks.
    # AssertionError is kept so existing callers that catch it keep working.
    for mod in models:
        if mod not in model_manager.model_names():
            raise AssertionError(f'** Unknown model "{mod}"')
        if model_manager.model_info(mod).get('format', None) != 'diffusers':
            raise AssertionError(
                f'** {mod} is not a diffusers model. It must be optimized before merging.'
            )

    model_ids_or_paths = [model_manager.model_name_or_path(x) for x in models]

    # The first model is loaded with the 'checkpoint_merger' custom pipeline,
    # which provides the merge() method used below.
    # NOTE(review): from_pretrained() falls back to global_cache_dir() when the
    # caller gives no cache_dir, but merge() only receives cache_dir if it is
    # present in kwargs - confirm this asymmetry is intended.
    pipe = DiffusionPipeline.from_pretrained(
        model_ids_or_paths[0],
        cache_dir=kwargs.get('cache_dir', global_cache_dir()),
        custom_pipeline='checkpoint_merger',
    )
    merged_pipe = pipe.merge(
        pretrained_model_name_or_path_list=model_ids_or_paths,
        alpha=alpha,
        interp=interp,
        force=force,
        **kwargs,
    )

    # Make sure the parent directory for merged models exists, then save the
    # merged pipeline into a subdirectory named after the new model.
    dump_path = global_models_dir() / 'merged_diffusers'
    os.makedirs(dump_path, exist_ok=True)
    dump_path = dump_path / merged_model_name
    merged_pipe.save_pretrained(
        dump_path,
        safe_serialization=True,
    )

    model_manager.import_diffuser_model(
        dump_path,
        model_name=merged_model_name,
        description=f'Merge of models {", ".join(models)}',
    )

    # TODO: When PR 2369 is merged, pass the VAE through a vae= argument to the
    # import call above (presumably import_diffuser_model) instead of patching
    # the configuration entry after the import.
    if vae := model_manager.config[models[0]].get('vae', None):
        print(f'>> Using configured VAE assigned to {models[0]}')
        model_manager.config[merged_model_name]['vae'] = vae

    model_manager.commit(config_file)