create small module for merge importation logic

This commit is contained in:
Lincoln Stein 2023-01-22 18:07:53 -05:00
parent f0fe483915
commit 6c31225d19
6 changed files with 91 additions and 52 deletions

View File

@ -29,6 +29,7 @@ else:
# Where to look for the initialization file # Where to look for the initialization file
Globals.initfile = 'invokeai.init' Globals.initfile = 'invokeai.init'
Globals.models_file = 'models.yaml'
Globals.models_dir = 'models' Globals.models_dir = 'models'
Globals.config_dir = 'configs' Globals.config_dir = 'configs'
Globals.autoscan_dir = 'weights' Globals.autoscan_dir = 'weights'
@ -49,6 +50,9 @@ Globals.disable_xformers = False
# whether we are forcing full precision # whether we are forcing full precision
Globals.full_precision = False Globals.full_precision = False
def global_config_file()->Path:
return Path(Globals.root, Globals.config_dir, Globals.models_file)
def global_config_dir()->Path: def global_config_dir()->Path:
return Path(Globals.root, Globals.config_dir) return Path(Globals.root, Globals.config_dir)

View File

@ -0,0 +1,59 @@
'''
ldm.invoke.merge_diffusers exports a single function call merge_diffusion_models()
used to merge 2-3 models together and create a new InvokeAI-registered diffusion model.
'''
import os
from typing import List
from diffusers import DiffusionPipeline
from ldm.invoke.globals import global_config_file, global_models_dir, global_cache_dir
from ldm.invoke.model_manager import ModelManager
from omegaconf import OmegaConf
def merge_diffusion_models(models:List['str'],
merged_model_name:str,
alpha:float=0.5,
interp:str=None,
force:bool=False,
**kwargs):
'''
models - up to three models, designated by their InvokeAI models.yaml model name
merged_model_name = name for new model
alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
interp - The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None.
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported.
force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
'''
config_file = global_config_file()
model_manager = ModelManager(OmegaConf.load(config_file))
model_ids_or_paths = [model_manager.model_name_or_path(x) for x in models]
pipe = DiffusionPipeline.from_pretrained(model_ids_or_paths[0],
cache_dir=kwargs.get('cache_dir',global_cache_dir()),
custom_pipeline='checkpoint_merger')
merged_pipe = pipe.merge(pretrained_model_name_or_path_list=model_ids_or_paths,
alpha=alpha,
interp=interp,
force=force,
**kwargs)
dump_path = global_models_dir() / 'merged_diffusers'
os.makedirs(dump_path,exist_ok=True)
dump_path = dump_path / merged_model_name
merged_pipe.save_pretrained (
dump_path,
safe_serialization=1
)
model_manager.import_diffuser_model(
dump_path,
model_name = merged_model_name,
description = f'Merge of models {", ".join(models)}'
)
print('REMINDER: When PR 2369 is merged, replace merge_diffusers.py line 56 with vae= argument to impormodel()')
if vae := model_manager.config[models[0]].get('vae',None):
print(f'>> Using configured VAE assigned to {models[0]}')
model_manager.config[merged_model_name]['vae'] = vae
model_manager.commit(config_file)

View File

@ -37,7 +37,11 @@ from ldm.util import instantiate_from_config, ask_user
DEFAULT_MAX_MODELS=2 DEFAULT_MAX_MODELS=2
class ModelManager(object): class ModelManager(object):
def __init__(self, config:OmegaConf, device_type:str, precision:str, max_loaded_models=DEFAULT_MAX_MODELS): def __init__(self,
config:OmegaConf,
device_type:str='cpu',
precision:str='float16',
max_loaded_models=DEFAULT_MAX_MODELS):
''' '''
Initialize with the path to the models.yaml config file, Initialize with the path to the models.yaml config file,
the torch device type, and precision. The optional the torch device type, and precision. The optional
@ -536,7 +540,7 @@ class ModelManager(object):
format='diffusers', format='diffusers',
) )
if isinstance(repo_or_path,Path) and repo_or_path.exists(): if isinstance(repo_or_path,Path) and repo_or_path.exists():
new_config.update(path=repo_or_path) new_config.update(path=str(repo_or_path))
else: else:
new_config.update(repo_id=repo_or_path) new_config.update(repo_id=repo_or_path)

0
scripts/load_models.py Normal file → Executable file
View File

0
scripts/merge_embeddings.py Normal file → Executable file
View File

View File

@ -5,15 +5,12 @@ import os
import sys import sys
import traceback import traceback
import argparse import argparse
import safetensors.torch from ldm.invoke.globals import Globals, global_set_root, global_cache_dir, global_config_file
from ldm.invoke.globals import Globals, global_set_root, global_cache_dir
from ldm.invoke.model_manager import ModelManager from ldm.invoke.model_manager import ModelManager
from omegaconf import OmegaConf from omegaconf import OmegaConf
from pathlib import Path from pathlib import Path
from typing import List from typing import List
CONFIG_FILE = None
class FloatSlider(npyscreen.Slider): class FloatSlider(npyscreen.Slider):
# this is supposed to adjust display precision, but doesn't # this is supposed to adjust display precision, but doesn't
def translate_value(self): def translate_value(self):
@ -120,16 +117,16 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
self.merge_method.value=0 self.merge_method.value=0
def on_ok(self): def on_ok(self):
if self.validate_field_values(): if self.validate_field_values() and self.check_for_overwrite():
self.parentApp.setNextForm(None) self.parentApp.setNextForm(None)
self.editing = False self.editing = False
self.parentApp.merge_arguments = self.marshall_arguments() self.parentApp.merge_arguments = self.marshall_arguments()
npyscreen.notify('Starting the merge...') npyscreen.notify('Starting the merge...')
import diffusers # this keeps the message up while diffusers loads import ldm.invoke.merge_diffusers # this keeps the message up while diffusers loads
else: else:
self.editing = True self.editing = True
def ok_cancel(self): def on_cancel(self):
sys.exit(0) sys.exit(0)
def marshall_arguments(self)->dict: def marshall_arguments(self)->dict:
@ -141,18 +138,22 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
if self.model3.value[0] > 0: if self.model3.value[0] > 0:
models.append(model_names[self.model3.value[0]-1]) models.append(model_names[self.model3.value[0]-1])
models = [self.model_manager.model_name_or_path(x) for x in models]
args = dict( args = dict(
pretrained_model_name_or_path_list=models, models=models,
alpha = self.alpha.value, alpha = self.alpha.value,
interp = self.interpolations[self.merge_method.value[0]], interp = self.interpolations[self.merge_method.value[0]],
force = self.force.value, force = self.force.value,
cache_dir = global_cache_dir('diffusers'),
merged_model_name = self.merged_model_name.value, merged_model_name = self.merged_model_name.value,
) )
return args return args
def check_for_overwrite(self)->bool:
model_out = self.merged_model_name.value
if model_out not in self.model_names:
return True
else:
return npyscreen.notify_yes_no(f'The chosen merged model destination, {model_out}, is already in use. Overwrite?')
def validate_field_values(self)->bool: def validate_field_values(self)->bool:
bad_fields = [] bad_fields = []
model_names = self.model_names model_names = self.model_names
@ -178,7 +179,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
class Mergeapp(npyscreen.NPSAppManaged): class Mergeapp(npyscreen.NPSAppManaged):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
conf = OmegaConf.load(Path(Globals.root) / 'configs' / 'models.yaml') conf = OmegaConf.load(global_config_file())
self.model_manager = ModelManager(conf,'cpu','float16') # precision doesn't really matter here self.model_manager = ModelManager(conf,'cpu','float16') # precision doesn't really matter here
def onStart(self): def onStart(self):
@ -196,50 +197,21 @@ if __name__ == '__main__':
args = parser.parse_args() args = parser.parse_args()
global_set_root(args.root_dir) global_set_root(args.root_dir)
CONFIG_FILE = os.path.join(Globals.root,'configs/models.yaml') cache_dir = str(global_cache_dir('diffusers')) # because not clear the merge pipeline is honoring cache_dir
os.environ['HF_HOME'] = str(global_cache_dir('diffusers')) os.environ['HF_HOME'] = cache_dir
mergeapp = Mergeapp() mergeapp = Mergeapp()
mergeapp.run() mergeapp.run()
from diffusers import DiffusionPipeline
args = mergeapp.merge_arguments args = mergeapp.merge_arguments
merged_model_name = args['merged_model_name'] args.update(cache_dir = cache_dir)
merged_pipe = None from ldm.invoke.merge_diffusers import merge_diffusion_models
print(args)
try: try:
print(f'DEBUG: {args["pretrained_model_name_or_path_list"][0]}') merge_diffusion_models(**args)
pipe = DiffusionPipeline.from_pretrained(args['pretrained_model_name_or_path_list'][0], print(f'>> Models merged into new model: "{args["merged_model_name"]}".')
custom_pipeline='checkpoint_merger'
)
merged_pipe = pipe.merge(**args)
dump_path = Path(Globals.root) / 'models' / 'merged_diffusers'
os.makedirs(dump_path,exist_ok=True)
dump_path = dump_path / merged_model_name
merged_pipe.save_pretrained (
dump_path,
safe_serialization=1
)
except Exception as e: except Exception as e:
print(f'** An error occurred while merging the pipelines: {str(e)}') print(f'** An error occurred while merging the pipelines: {str(e)}')
print('** DETAILS:') print('** DETAILS:')
print(traceback.format_exc()) print(traceback.format_exc())
sys.exit(-1) sys.exit(-1)
print(f'>> Merged model is saved to {dump_path}')
response = input('Import this model into InvokeAI? [y]').strip() or 'y'
if response.startswith(('y','Y')):
try:
mergeapp.model_manager.import_diffuser_model(
dump_path,
model_name = merged_model_name,
description = f'Merge of models {args["pretrained_model_name_or_path_list"]}'
)
mergeapp.model_manager.commit(CONFIG_FILE)
print('>> Merged model imported.')
except Exception as e:
print(f'** New model could not be committed to config.yaml: {str(e)}')
print('** DETAILS:')
print(traceback.format_exc())