mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
working, but needs diffusers PR to be accepted
This commit is contained in:
parent
f169bb0020
commit
4ee8d104f0
111
scripts/merge_fe.py
Normal file → Executable file
111
scripts/merge_fe.py
Normal file → Executable file
@ -3,15 +3,17 @@
|
|||||||
import npyscreen
|
import npyscreen
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import re
|
|
||||||
import shutil
|
|
||||||
import traceback
|
import traceback
|
||||||
import argparse
|
import argparse
|
||||||
from ldm.invoke.globals import Globals, global_set_root
|
import safetensors.torch
|
||||||
|
from ldm.invoke.globals import Globals, global_set_root, global_cache_dir
|
||||||
|
from ldm.invoke.model_manager import ModelManager
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
CONFIG_FILE = None
|
||||||
|
|
||||||
class FloatSlider(npyscreen.Slider):
|
class FloatSlider(npyscreen.Slider):
|
||||||
# this is supposed to adjust display precision, but doesn't
|
# this is supposed to adjust display precision, but doesn't
|
||||||
def translate_value(self):
|
def translate_value(self):
|
||||||
@ -30,6 +32,14 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
|||||||
'inv_sigmoid',
|
'inv_sigmoid',
|
||||||
'add_difference']
|
'add_difference']
|
||||||
|
|
||||||
|
def __init__(self, parentApp, name):
|
||||||
|
self.parentApp = parentApp
|
||||||
|
super().__init__(parentApp, name)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_manager(self):
|
||||||
|
return self.parentApp.model_manager
|
||||||
|
|
||||||
def afterEditing(self):
|
def afterEditing(self):
|
||||||
self.parentApp.setNextForm(None)
|
self.parentApp.setNextForm(None)
|
||||||
|
|
||||||
@ -83,6 +93,11 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
|||||||
lowest=0,
|
lowest=0,
|
||||||
value=0.5,
|
value=0.5,
|
||||||
)
|
)
|
||||||
|
self.force = self.add_widget_intelligent(
|
||||||
|
npyscreen.Checkbox,
|
||||||
|
name='Force merge of incompatible models',
|
||||||
|
value=False,
|
||||||
|
)
|
||||||
self.merged_model_name = self.add_widget_intelligent(
|
self.merged_model_name = self.add_widget_intelligent(
|
||||||
npyscreen.TitleText,
|
npyscreen.TitleText,
|
||||||
name='Name for merged model',
|
name='Name for merged model',
|
||||||
@ -108,17 +123,44 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
|||||||
if self.validate_field_values():
|
if self.validate_field_values():
|
||||||
self.parentApp.setNextForm(None)
|
self.parentApp.setNextForm(None)
|
||||||
self.editing = False
|
self.editing = False
|
||||||
|
self.parentApp.merge_arguments = self.marshall_arguments()
|
||||||
|
npyscreen.notify('Starting the merge...')
|
||||||
|
import diffusers # this keeps the message up while diffusers loads
|
||||||
else:
|
else:
|
||||||
self.editing = True
|
self.editing = True
|
||||||
|
|
||||||
def ok_cancel(self):
|
def ok_cancel(self):
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
|
def marshall_arguments(self)->dict:
|
||||||
|
model_names = self.model_names
|
||||||
|
models = [
|
||||||
|
model_names[self.model1.value[0]],
|
||||||
|
model_names[self.model2.value[0]],
|
||||||
|
]
|
||||||
|
if self.model3.value[0] > 0:
|
||||||
|
models.append(model_names[self.model3.value[0]-1])
|
||||||
|
|
||||||
|
models = [self.model_manager.model_name_or_path(x) for x in models]
|
||||||
|
|
||||||
|
args = dict(
|
||||||
|
pretrained_model_name_or_path_list=models,
|
||||||
|
alpha = self.alpha.value,
|
||||||
|
interp = self.interpolations[self.merge_method.value[0]],
|
||||||
|
force = self.force.value,
|
||||||
|
cache_dir = global_cache_dir('diffusers'),
|
||||||
|
merged_model_name = self.merged_model_name.value,
|
||||||
|
)
|
||||||
|
return args
|
||||||
|
|
||||||
def validate_field_values(self)->bool:
|
def validate_field_values(self)->bool:
|
||||||
bad_fields = []
|
bad_fields = []
|
||||||
selected_models = set((self.model1.value[0],self.model2.value[0],self.model3.value[0]))
|
model_names = self.model_names
|
||||||
if len(selected_models) < 3:
|
selected_models = set((model_names[self.model1.value[0]],model_names[self.model2.value[0]]))
|
||||||
bad_fields.append('Please select two or three DIFFERENT models to compare')
|
if self.model3.value[0] > 0:
|
||||||
|
selected_models.add(model_names[self.model3.value[0]-1])
|
||||||
|
if len(selected_models) < 2:
|
||||||
|
bad_fields.append(f'Please select two or three DIFFERENT models to compare. You selected {selected_models}')
|
||||||
if len(bad_fields) > 0:
|
if len(bad_fields) > 0:
|
||||||
message = 'The following problems were detected and must be corrected:'
|
message = 'The following problems were detected and must be corrected:'
|
||||||
for problem in bad_fields:
|
for problem in bad_fields:
|
||||||
@ -129,13 +171,15 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
def get_model_names(self)->List[str]:
|
def get_model_names(self)->List[str]:
|
||||||
conf = OmegaConf.load(os.path.join(Globals.root,'configs/models.yaml'))
|
model_names = [name for name in self.model_manager.model_names() if self.model_manager.model_info(name).get('format') == 'diffusers']
|
||||||
model_names = [name for name in conf.keys() if conf[name].get('format',None)=='diffusers']
|
print(model_names)
|
||||||
return sorted(model_names)
|
return sorted(model_names)
|
||||||
|
|
||||||
class MyApplication(npyscreen.NPSAppManaged):
|
class Mergeapp(npyscreen.NPSAppManaged):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
conf = OmegaConf.load(Path(Globals.root) / 'configs' / 'models.yaml')
|
||||||
|
self.model_manager = ModelManager(conf,'cpu','float16') # precision doesn't really matter here
|
||||||
|
|
||||||
def onStart(self):
|
def onStart(self):
|
||||||
npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
|
npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
|
||||||
@ -151,6 +195,51 @@ if __name__ == '__main__':
|
|||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
global_set_root(args.root_dir)
|
global_set_root(args.root_dir)
|
||||||
|
|
||||||
|
CONFIG_FILE = os.path.join(Globals.root,'configs/models.yaml')
|
||||||
|
os.environ['HF_HOME'] = str(global_cache_dir('diffusers'))
|
||||||
|
|
||||||
|
mergeapp = Mergeapp()
|
||||||
|
mergeapp.run()
|
||||||
|
from diffusers import DiffusionPipeline
|
||||||
|
args = mergeapp.merge_arguments
|
||||||
|
merged_model_name = args['merged_model_name']
|
||||||
|
merged_pipe = None
|
||||||
|
print(args)
|
||||||
|
|
||||||
|
try:
|
||||||
|
print(f'DEBUG: {args["pretrained_model_name_or_path_list"][0]}')
|
||||||
|
pipe = DiffusionPipeline.from_pretrained(args['pretrained_model_name_or_path_list'][0],
|
||||||
|
custom_pipeline='checkpoint_merger'
|
||||||
|
)
|
||||||
|
merged_pipe = pipe.merge(**args)
|
||||||
|
dump_path = Path(Globals.root) / 'models' / 'merged_diffusers'
|
||||||
|
os.makedirs(dump_path,exist_ok=True)
|
||||||
|
dump_path = dump_path / merged_model_name
|
||||||
|
merged_pipe.save_pretrained (
|
||||||
|
dump_path,
|
||||||
|
safe_serialization=1
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(f'** An error occurred while merging the pipelines: {str(e)}')
|
||||||
|
print('** DETAILS:')
|
||||||
|
print(traceback.format_exc())
|
||||||
|
sys.exit(-1)
|
||||||
|
|
||||||
|
print(f'>> Merged model is saved to {dump_path}')
|
||||||
|
response = input('Import this model into InvokeAI? [y]').strip() or 'y'
|
||||||
|
if response.startswith(('y','Y')):
|
||||||
|
try:
|
||||||
|
mergeapp.model_manager.import_diffuser_model(
|
||||||
|
dump_path,
|
||||||
|
model_name = merged_model_name,
|
||||||
|
description = f'Merge of models {args["pretrained_model_name_or_path_list"]}'
|
||||||
|
)
|
||||||
|
mergeapp.model_manager.commit(CONFIG_FILE)
|
||||||
|
print('>> Merged model imported.')
|
||||||
|
except Exception as e:
|
||||||
|
print(f'** New model could not be committed to config.yaml: {str(e)}')
|
||||||
|
print('** DETAILS:')
|
||||||
|
print(traceback.format_exc())
|
||||||
|
|
||||||
|
|
||||||
myapplication = MyApplication()
|
|
||||||
myapplication.run()
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user