mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
create small module for merge importation logic
This commit is contained in:
parent
f0fe483915
commit
6c31225d19
@ -29,6 +29,7 @@ else:
|
|||||||
|
|
||||||
# Where to look for the initialization file
|
# Where to look for the initialization file
|
||||||
Globals.initfile = 'invokeai.init'
|
Globals.initfile = 'invokeai.init'
|
||||||
|
Globals.models_file = 'models.yaml'
|
||||||
Globals.models_dir = 'models'
|
Globals.models_dir = 'models'
|
||||||
Globals.config_dir = 'configs'
|
Globals.config_dir = 'configs'
|
||||||
Globals.autoscan_dir = 'weights'
|
Globals.autoscan_dir = 'weights'
|
||||||
@ -49,6 +50,9 @@ Globals.disable_xformers = False
|
|||||||
# whether we are forcing full precision
|
# whether we are forcing full precision
|
||||||
Globals.full_precision = False
|
Globals.full_precision = False
|
||||||
|
|
||||||
|
def global_config_file()->Path:
|
||||||
|
return Path(Globals.root, Globals.config_dir, Globals.models_file)
|
||||||
|
|
||||||
def global_config_dir()->Path:
|
def global_config_dir()->Path:
|
||||||
return Path(Globals.root, Globals.config_dir)
|
return Path(Globals.root, Globals.config_dir)
|
||||||
|
|
||||||
|
59
ldm/invoke/merge_diffusers.py
Normal file
59
ldm/invoke/merge_diffusers.py
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
'''
|
||||||
|
ldm.invoke.merge_diffusers exports a single function call merge_diffusion_models()
|
||||||
|
used to merge 2-3 models together and create a new InvokeAI-registered diffusion model.
|
||||||
|
'''
|
||||||
|
import os
|
||||||
|
from typing import List
|
||||||
|
from diffusers import DiffusionPipeline
|
||||||
|
from ldm.invoke.globals import global_config_file, global_models_dir, global_cache_dir
|
||||||
|
from ldm.invoke.model_manager import ModelManager
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
|
def merge_diffusion_models(models:List['str'],
|
||||||
|
merged_model_name:str,
|
||||||
|
alpha:float=0.5,
|
||||||
|
interp:str=None,
|
||||||
|
force:bool=False,
|
||||||
|
**kwargs):
|
||||||
|
'''
|
||||||
|
models - up to three models, designated by their InvokeAI models.yaml model name
|
||||||
|
merged_model_name = name for new model
|
||||||
|
alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
|
||||||
|
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
|
||||||
|
interp - The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None.
|
||||||
|
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported.
|
||||||
|
force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
|
||||||
|
|
||||||
|
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
||||||
|
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
||||||
|
'''
|
||||||
|
config_file = global_config_file()
|
||||||
|
model_manager = ModelManager(OmegaConf.load(config_file))
|
||||||
|
model_ids_or_paths = [model_manager.model_name_or_path(x) for x in models]
|
||||||
|
|
||||||
|
pipe = DiffusionPipeline.from_pretrained(model_ids_or_paths[0],
|
||||||
|
cache_dir=kwargs.get('cache_dir',global_cache_dir()),
|
||||||
|
custom_pipeline='checkpoint_merger')
|
||||||
|
merged_pipe = pipe.merge(pretrained_model_name_or_path_list=model_ids_or_paths,
|
||||||
|
alpha=alpha,
|
||||||
|
interp=interp,
|
||||||
|
force=force,
|
||||||
|
**kwargs)
|
||||||
|
dump_path = global_models_dir() / 'merged_diffusers'
|
||||||
|
os.makedirs(dump_path,exist_ok=True)
|
||||||
|
dump_path = dump_path / merged_model_name
|
||||||
|
merged_pipe.save_pretrained (
|
||||||
|
dump_path,
|
||||||
|
safe_serialization=1
|
||||||
|
)
|
||||||
|
model_manager.import_diffuser_model(
|
||||||
|
dump_path,
|
||||||
|
model_name = merged_model_name,
|
||||||
|
description = f'Merge of models {", ".join(models)}'
|
||||||
|
)
|
||||||
|
print('REMINDER: When PR 2369 is merged, replace merge_diffusers.py line 56 with vae= argument to impormodel()')
|
||||||
|
if vae := model_manager.config[models[0]].get('vae',None):
|
||||||
|
print(f'>> Using configured VAE assigned to {models[0]}')
|
||||||
|
model_manager.config[merged_model_name]['vae'] = vae
|
||||||
|
|
||||||
|
model_manager.commit(config_file)
|
@ -37,7 +37,11 @@ from ldm.util import instantiate_from_config, ask_user
|
|||||||
DEFAULT_MAX_MODELS=2
|
DEFAULT_MAX_MODELS=2
|
||||||
|
|
||||||
class ModelManager(object):
|
class ModelManager(object):
|
||||||
def __init__(self, config:OmegaConf, device_type:str, precision:str, max_loaded_models=DEFAULT_MAX_MODELS):
|
def __init__(self,
|
||||||
|
config:OmegaConf,
|
||||||
|
device_type:str='cpu',
|
||||||
|
precision:str='float16',
|
||||||
|
max_loaded_models=DEFAULT_MAX_MODELS):
|
||||||
'''
|
'''
|
||||||
Initialize with the path to the models.yaml config file,
|
Initialize with the path to the models.yaml config file,
|
||||||
the torch device type, and precision. The optional
|
the torch device type, and precision. The optional
|
||||||
@ -536,7 +540,7 @@ class ModelManager(object):
|
|||||||
format='diffusers',
|
format='diffusers',
|
||||||
)
|
)
|
||||||
if isinstance(repo_or_path,Path) and repo_or_path.exists():
|
if isinstance(repo_or_path,Path) and repo_or_path.exists():
|
||||||
new_config.update(path=repo_or_path)
|
new_config.update(path=str(repo_or_path))
|
||||||
else:
|
else:
|
||||||
new_config.update(repo_id=repo_or_path)
|
new_config.update(repo_id=repo_or_path)
|
||||||
|
|
||||||
|
0
scripts/load_models.py
Normal file → Executable file
0
scripts/load_models.py
Normal file → Executable file
0
scripts/merge_embeddings.py
Normal file → Executable file
0
scripts/merge_embeddings.py
Normal file → Executable file
@ -5,15 +5,12 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
import argparse
|
import argparse
|
||||||
import safetensors.torch
|
from ldm.invoke.globals import Globals, global_set_root, global_cache_dir, global_config_file
|
||||||
from ldm.invoke.globals import Globals, global_set_root, global_cache_dir
|
|
||||||
from ldm.invoke.model_manager import ModelManager
|
from ldm.invoke.model_manager import ModelManager
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
CONFIG_FILE = None
|
|
||||||
|
|
||||||
class FloatSlider(npyscreen.Slider):
|
class FloatSlider(npyscreen.Slider):
|
||||||
# this is supposed to adjust display precision, but doesn't
|
# this is supposed to adjust display precision, but doesn't
|
||||||
def translate_value(self):
|
def translate_value(self):
|
||||||
@ -120,16 +117,16 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
|||||||
self.merge_method.value=0
|
self.merge_method.value=0
|
||||||
|
|
||||||
def on_ok(self):
|
def on_ok(self):
|
||||||
if self.validate_field_values():
|
if self.validate_field_values() and self.check_for_overwrite():
|
||||||
self.parentApp.setNextForm(None)
|
self.parentApp.setNextForm(None)
|
||||||
self.editing = False
|
self.editing = False
|
||||||
self.parentApp.merge_arguments = self.marshall_arguments()
|
self.parentApp.merge_arguments = self.marshall_arguments()
|
||||||
npyscreen.notify('Starting the merge...')
|
npyscreen.notify('Starting the merge...')
|
||||||
import diffusers # this keeps the message up while diffusers loads
|
import ldm.invoke.merge_diffusers # this keeps the message up while diffusers loads
|
||||||
else:
|
else:
|
||||||
self.editing = True
|
self.editing = True
|
||||||
|
|
||||||
def ok_cancel(self):
|
def on_cancel(self):
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
def marshall_arguments(self)->dict:
|
def marshall_arguments(self)->dict:
|
||||||
@ -141,18 +138,22 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
|||||||
if self.model3.value[0] > 0:
|
if self.model3.value[0] > 0:
|
||||||
models.append(model_names[self.model3.value[0]-1])
|
models.append(model_names[self.model3.value[0]-1])
|
||||||
|
|
||||||
models = [self.model_manager.model_name_or_path(x) for x in models]
|
|
||||||
|
|
||||||
args = dict(
|
args = dict(
|
||||||
pretrained_model_name_or_path_list=models,
|
models=models,
|
||||||
alpha = self.alpha.value,
|
alpha = self.alpha.value,
|
||||||
interp = self.interpolations[self.merge_method.value[0]],
|
interp = self.interpolations[self.merge_method.value[0]],
|
||||||
force = self.force.value,
|
force = self.force.value,
|
||||||
cache_dir = global_cache_dir('diffusers'),
|
|
||||||
merged_model_name = self.merged_model_name.value,
|
merged_model_name = self.merged_model_name.value,
|
||||||
)
|
)
|
||||||
return args
|
return args
|
||||||
|
|
||||||
|
def check_for_overwrite(self)->bool:
|
||||||
|
model_out = self.merged_model_name.value
|
||||||
|
if model_out not in self.model_names:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return npyscreen.notify_yes_no(f'The chosen merged model destination, {model_out}, is already in use. Overwrite?')
|
||||||
|
|
||||||
def validate_field_values(self)->bool:
|
def validate_field_values(self)->bool:
|
||||||
bad_fields = []
|
bad_fields = []
|
||||||
model_names = self.model_names
|
model_names = self.model_names
|
||||||
@ -178,7 +179,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
|||||||
class Mergeapp(npyscreen.NPSAppManaged):
|
class Mergeapp(npyscreen.NPSAppManaged):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
conf = OmegaConf.load(Path(Globals.root) / 'configs' / 'models.yaml')
|
conf = OmegaConf.load(global_config_file())
|
||||||
self.model_manager = ModelManager(conf,'cpu','float16') # precision doesn't really matter here
|
self.model_manager = ModelManager(conf,'cpu','float16') # precision doesn't really matter here
|
||||||
|
|
||||||
def onStart(self):
|
def onStart(self):
|
||||||
@ -196,50 +197,21 @@ if __name__ == '__main__':
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
global_set_root(args.root_dir)
|
global_set_root(args.root_dir)
|
||||||
|
|
||||||
CONFIG_FILE = os.path.join(Globals.root,'configs/models.yaml')
|
cache_dir = str(global_cache_dir('diffusers')) # because not clear the merge pipeline is honoring cache_dir
|
||||||
os.environ['HF_HOME'] = str(global_cache_dir('diffusers'))
|
os.environ['HF_HOME'] = cache_dir
|
||||||
|
|
||||||
mergeapp = Mergeapp()
|
mergeapp = Mergeapp()
|
||||||
mergeapp.run()
|
mergeapp.run()
|
||||||
from diffusers import DiffusionPipeline
|
|
||||||
args = mergeapp.merge_arguments
|
args = mergeapp.merge_arguments
|
||||||
merged_model_name = args['merged_model_name']
|
args.update(cache_dir = cache_dir)
|
||||||
merged_pipe = None
|
from ldm.invoke.merge_diffusers import merge_diffusion_models
|
||||||
print(args)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
print(f'DEBUG: {args["pretrained_model_name_or_path_list"][0]}')
|
merge_diffusion_models(**args)
|
||||||
pipe = DiffusionPipeline.from_pretrained(args['pretrained_model_name_or_path_list'][0],
|
print(f'>> Models merged into new model: "{args["merged_model_name"]}".')
|
||||||
custom_pipeline='checkpoint_merger'
|
|
||||||
)
|
|
||||||
merged_pipe = pipe.merge(**args)
|
|
||||||
dump_path = Path(Globals.root) / 'models' / 'merged_diffusers'
|
|
||||||
os.makedirs(dump_path,exist_ok=True)
|
|
||||||
dump_path = dump_path / merged_model_name
|
|
||||||
merged_pipe.save_pretrained (
|
|
||||||
dump_path,
|
|
||||||
safe_serialization=1
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f'** An error occurred while merging the pipelines: {str(e)}')
|
print(f'** An error occurred while merging the pipelines: {str(e)}')
|
||||||
print('** DETAILS:')
|
print('** DETAILS:')
|
||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
sys.exit(-1)
|
sys.exit(-1)
|
||||||
|
|
||||||
print(f'>> Merged model is saved to {dump_path}')
|
|
||||||
response = input('Import this model into InvokeAI? [y]').strip() or 'y'
|
|
||||||
if response.startswith(('y','Y')):
|
|
||||||
try:
|
|
||||||
mergeapp.model_manager.import_diffuser_model(
|
|
||||||
dump_path,
|
|
||||||
model_name = merged_model_name,
|
|
||||||
description = f'Merge of models {args["pretrained_model_name_or_path_list"]}'
|
|
||||||
)
|
|
||||||
mergeapp.model_manager.commit(CONFIG_FILE)
|
|
||||||
print('>> Merged model imported.')
|
|
||||||
except Exception as e:
|
|
||||||
print(f'** New model could not be committed to config.yaml: {str(e)}')
|
|
||||||
print('** DETAILS:')
|
|
||||||
print(traceback.format_exc())
|
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user