mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
[Feature] Add interactive diffusers model merger (#2388)
This PR adds `scripts/merge_fe.py`, which will merge any 2-3 diffusers models registered in InvokeAI's `models.yaml`, producing a new merged model that will be registered as well. Currently this script will only work if all models to be merged are known by their repo_ids. Local models, including those converted from ckpt files, will cause a crash due to a bug in the diffusers `checkpoint_merger.py` code. I have made a PR against huggingface/diffusers which fixes this: https://github.com/huggingface/diffusers/pull/2060
This commit is contained in:
commit
aca1b61413
77
docs/features/MODEL_MERGING.md
Normal file
77
docs/features/MODEL_MERGING.md
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
---
|
||||||
|
title: Model Merging
|
||||||
|
---
|
||||||
|
|
||||||
|
# :material-image-off: Model Merging
|
||||||
|
|
||||||
|
## How to Merge Models
|
||||||
|
|
||||||
|
As of version 2.3, InvokeAI comes with a script that allows you to
|
||||||
|
merge two or three diffusers-type models into a new merged model. The
|
||||||
|
resulting model will combine characteristics of the original, and can
|
||||||
|
be used to teach an old model new tricks.
|
||||||
|
|
||||||
|
You may run the merge script by starting the invoke launcher
|
||||||
|
(`invoke.sh` or `invoke.bat`) and choosing the option for _merge
|
||||||
|
models_. This will launch a text-based interactive user interface that
|
||||||
|
prompts you to select the models to merge, how to merge them, and the
|
||||||
|
merged model name.
|
||||||
|
|
||||||
|
Alternatively you may activate InvokeAI's virtual environment from the
|
||||||
|
command line, and call the script via `merge_models_fe.py` (the "fe"
|
||||||
|
stands for "front end"). There is also a version that accepts
|
||||||
|
command-line arguments, which you can run with the command
|
||||||
|
`merge_models.py`.
|
||||||
|
|
||||||
|
The user interface for the text-based interactive script is
|
||||||
|
straightforward. It shows you a series of setting fields. Use control-N (^N)
|
||||||
|
to move to the next field, and control-P (^P) to move to the previous
|
||||||
|
one. You can also use TAB and shift-TAB to move forward and
|
||||||
|
backward. Once you are in a multiple choice field, use the up and down
|
||||||
|
cursor arrows to move to your desired selection, and press <SPACE> or
|
||||||
|
<ENTER> to select it. Change text fields by typing in them, and adjust
|
||||||
|
scrollbars using the left and right arrow keys.
|
||||||
|
|
||||||
|
Once you are happy with your settings, press the OK button. Note that
|
||||||
|
there may be two pages of settings, depending on the height of your
|
||||||
|
screen, and the OK button may be on the second page. Advance past the
|
||||||
|
last field of the first page to get to the second page, and reverse
|
||||||
|
this to get back.
|
||||||
|
|
||||||
|
If the merge runs successfully, it will create a new diffusers model
|
||||||
|
under the selected name and register it with InvokeAI.
|
||||||
|
|
||||||
|
## The Settings
|
||||||
|
|
||||||
|
* Model Selection -- there are three multiple choice fields that
|
||||||
|
display all the diffusers-style models that InvokeAI knows about.
|
||||||
|
If you do not see the model you are looking for, then it is probably
|
||||||
|
a legacy checkpoint model and needs to be converted using the
|
||||||
|
`invoke.py` command-line client and its `!optimize` command. You
|
||||||
|
must select at least two models to merge. The third can be left at
|
||||||
|
"None" if you desire.
|
||||||
|
|
||||||
|
* Alpha -- This is the ratio to use when combining models. It ranges
|
||||||
|
from 0 to 1. The higher the value, the more weight is given to the
|
||||||
|
2d and (optionally) 3d models. So if you have two models named "A"
|
||||||
|
and "B", an alpha value of 0.25 will give you a merged model that is
|
||||||
|
25% A and 75% B.
|
||||||
|
|
||||||
|
* Interpolation Method -- This is the method used to combine
|
||||||
|
weights. The options are "weighted_sum" (the default), "sigmoid",
|
||||||
|
"inv_sigmoid" and "add_difference". Each produces slightly different
|
||||||
|
results. When three models are in use, only "add_difference" is
|
||||||
|
available. (TODO: cite a reference that describes what these
|
||||||
|
interpolation methods actually do and how to decide among them).
|
||||||
|
|
||||||
|
* Force -- Not all models are compatible with each other. The merge
|
||||||
|
script will check for compatibility and refuse to merge ones that
|
||||||
|
are incompatible. Set this checkbox to try merging anyway.
|
||||||
|
|
||||||
|
* Name for merged model - This is the name for the new model. Please
|
||||||
|
use InvokeAI conventions - only alphanumeric letters and the
|
||||||
|
characters ".+-".
|
||||||
|
|
||||||
|
## Caveats
|
||||||
|
|
||||||
|
This is a new script and may contain bugs.
|
@ -157,6 +157,8 @@ images in full-precision mode:
|
|||||||
<!-- seperator -->
|
<!-- seperator -->
|
||||||
- [Prompt Engineering](features/PROMPTS.md)
|
- [Prompt Engineering](features/PROMPTS.md)
|
||||||
<!-- seperator -->
|
<!-- seperator -->
|
||||||
|
- [Model Merging](features/MODEL_MERGING.md)
|
||||||
|
<!-- seperator -->
|
||||||
- Miscellaneous
|
- Miscellaneous
|
||||||
- [NSFW Checker](features/NSFW.md)
|
- [NSFW Checker](features/NSFW.md)
|
||||||
- [Embiggen upscaling](features/EMBIGGEN.md)
|
- [Embiggen upscaling](features/EMBIGGEN.md)
|
||||||
|
@ -10,8 +10,9 @@ echo Do you want to generate images using the
|
|||||||
echo 1. command-line
|
echo 1. command-line
|
||||||
echo 2. browser-based UI
|
echo 2. browser-based UI
|
||||||
echo 3. run textual inversion training
|
echo 3. run textual inversion training
|
||||||
echo 4. open the developer console
|
echo 4. merge models (diffusers type only)
|
||||||
echo 5. re-run the configure script to download new models
|
echo 5. open the developer console
|
||||||
|
echo 6. re-run the configure script to download new models
|
||||||
set /P restore="Please enter 1, 2, 3, 4 or 5: [5] "
|
set /P restore="Please enter 1, 2, 3, 4 or 5: [5] "
|
||||||
if not defined restore set restore=2
|
if not defined restore set restore=2
|
||||||
IF /I "%restore%" == "1" (
|
IF /I "%restore%" == "1" (
|
||||||
@ -24,6 +25,9 @@ IF /I "%restore%" == "1" (
|
|||||||
echo Starting textual inversion training..
|
echo Starting textual inversion training..
|
||||||
python .venv\Scripts\textual_inversion_fe.py --web %*
|
python .venv\Scripts\textual_inversion_fe.py --web %*
|
||||||
) ELSE IF /I "%restore%" == "4" (
|
) ELSE IF /I "%restore%" == "4" (
|
||||||
|
echo Starting model merging script..
|
||||||
|
python .venv\Scripts\merge_models_fe.py --web %*
|
||||||
|
) ELSE IF /I "%restore%" == "5" (
|
||||||
echo Developer Console
|
echo Developer Console
|
||||||
echo Python command is:
|
echo Python command is:
|
||||||
where python
|
where python
|
||||||
@ -35,7 +39,7 @@ IF /I "%restore%" == "1" (
|
|||||||
echo *************************
|
echo *************************
|
||||||
echo *** Type `exit` to quit this shell and deactivate the Python virtual environment ***
|
echo *** Type `exit` to quit this shell and deactivate the Python virtual environment ***
|
||||||
call cmd /k
|
call cmd /k
|
||||||
) ELSE IF /I "%restore%" == "5" (
|
) ELSE IF /I "%restore%" == "6" (
|
||||||
echo Running configure_invokeai.py...
|
echo Running configure_invokeai.py...
|
||||||
python .venv\Scripts\configure_invokeai.py --web %*
|
python .venv\Scripts\configure_invokeai.py --web %*
|
||||||
) ELSE (
|
) ELSE (
|
||||||
|
@ -20,16 +20,18 @@ if [ "$0" != "bash" ]; then
|
|||||||
echo "1. command-line"
|
echo "1. command-line"
|
||||||
echo "2. browser-based UI"
|
echo "2. browser-based UI"
|
||||||
echo "3. run textual inversion training"
|
echo "3. run textual inversion training"
|
||||||
echo "4. open the developer console"
|
echo "4. merge models (diffusers type only)"
|
||||||
echo "5. re-run the configure script to download new models"
|
echo "5. re-run the configure script to download new models"
|
||||||
|
echo "6. open the developer console"
|
||||||
read -p "Please enter 1, 2, 3, 4 or 5: [1] " yn
|
read -p "Please enter 1, 2, 3, 4 or 5: [1] " yn
|
||||||
choice=${yn:='2'}
|
choice=${yn:='2'}
|
||||||
case $choice in
|
case $choice in
|
||||||
1 ) printf "\nStarting the InvokeAI command-line..\n"; .venv/bin/python .venv/bin/invoke.py $*;;
|
1 ) printf "\nStarting the InvokeAI command-line..\n"; .venv/bin/python .venv/bin/invoke.py $*;;
|
||||||
2 ) printf "\nStarting the InvokeAI browser-based UI..\n"; .venv/bin/python .venv/bin/invoke.py --web $*;;
|
2 ) printf "\nStarting the InvokeAI browser-based UI..\n"; .venv/bin/python .venv/bin/invoke.py --web $*;;
|
||||||
3 ) printf "\nStarting Textual Inversion:\n"; .venv/bin/python .venv/bin/textual_inversion_fe.py $*;;
|
3 ) printf "\nStarting Textual Inversion:\n"; .venv/bin/python .venv/bin/textual_inversion_fe.py $*;;
|
||||||
4 ) printf "\nDeveloper Console:\n"; file_name=$(basename "${BASH_SOURCE[0]}"); bash --init-file "$file_name";;
|
4 ) printf "\nMerging Models:\n"; .venv/bin/python .venv/bin/merge_models_fe.py $*;;
|
||||||
5 ) printf "\nRunning configure_invokeai.py:\n"; .venv/bin/python .venv/bin/configure_invokeai.py $*;;
|
5 ) printf "\nDeveloper Console:\n"; file_name=$(basename "${BASH_SOURCE[0]}"); bash --init-file "$file_name";;
|
||||||
|
6 ) printf "\nRunning configure_invokeai.py:\n"; .venv/bin/python .venv/bin/configure_invokeai.py $*;;
|
||||||
* ) echo "Invalid selection"; exit;;
|
* ) echo "Invalid selection"; exit;;
|
||||||
esac
|
esac
|
||||||
else # in developer console
|
else # in developer console
|
||||||
|
@ -29,6 +29,7 @@ else:
|
|||||||
|
|
||||||
# Where to look for the initialization file
|
# Where to look for the initialization file
|
||||||
Globals.initfile = 'invokeai.init'
|
Globals.initfile = 'invokeai.init'
|
||||||
|
Globals.models_file = 'models.yaml'
|
||||||
Globals.models_dir = 'models'
|
Globals.models_dir = 'models'
|
||||||
Globals.config_dir = 'configs'
|
Globals.config_dir = 'configs'
|
||||||
Globals.autoscan_dir = 'weights'
|
Globals.autoscan_dir = 'weights'
|
||||||
@ -49,6 +50,9 @@ Globals.disable_xformers = False
|
|||||||
# whether we are forcing full precision
|
# whether we are forcing full precision
|
||||||
Globals.full_precision = False
|
Globals.full_precision = False
|
||||||
|
|
||||||
|
def global_config_file()->Path:
|
||||||
|
return Path(Globals.root, Globals.config_dir, Globals.models_file)
|
||||||
|
|
||||||
def global_config_dir()->Path:
|
def global_config_dir()->Path:
|
||||||
return Path(Globals.root, Globals.config_dir)
|
return Path(Globals.root, Globals.config_dir)
|
||||||
|
|
||||||
|
62
ldm/invoke/merge_diffusers.py
Normal file
62
ldm/invoke/merge_diffusers.py
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
'''
|
||||||
|
ldm.invoke.merge_diffusers exports a single function call merge_diffusion_models()
|
||||||
|
used to merge 2-3 models together and create a new InvokeAI-registered diffusion model.
|
||||||
|
'''
|
||||||
|
import os
|
||||||
|
from typing import List
|
||||||
|
from diffusers import DiffusionPipeline
|
||||||
|
from ldm.invoke.globals import global_config_file, global_models_dir, global_cache_dir
|
||||||
|
from ldm.invoke.model_manager import ModelManager
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
|
def merge_diffusion_models(models:List['str'],
|
||||||
|
merged_model_name:str,
|
||||||
|
alpha:float=0.5,
|
||||||
|
interp:str=None,
|
||||||
|
force:bool=False,
|
||||||
|
**kwargs):
|
||||||
|
'''
|
||||||
|
models - up to three models, designated by their InvokeAI models.yaml model name
|
||||||
|
merged_model_name = name for new model
|
||||||
|
alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
|
||||||
|
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
|
||||||
|
interp - The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None.
|
||||||
|
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported.
|
||||||
|
force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
|
||||||
|
|
||||||
|
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
||||||
|
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
||||||
|
'''
|
||||||
|
config_file = global_config_file()
|
||||||
|
model_manager = ModelManager(OmegaConf.load(config_file))
|
||||||
|
for mod in models:
|
||||||
|
assert (mod in model_manager.model_names()), f'** Unknown model "{mod}"'
|
||||||
|
assert (model_manager.model_info(mod).get('format',None) == 'diffusers'), f'** {mod} is not a diffusers model. It must be optimized before merging.'
|
||||||
|
model_ids_or_paths = [model_manager.model_name_or_path(x) for x in models]
|
||||||
|
|
||||||
|
pipe = DiffusionPipeline.from_pretrained(model_ids_or_paths[0],
|
||||||
|
cache_dir=kwargs.get('cache_dir',global_cache_dir()),
|
||||||
|
custom_pipeline='checkpoint_merger')
|
||||||
|
merged_pipe = pipe.merge(pretrained_model_name_or_path_list=model_ids_or_paths,
|
||||||
|
alpha=alpha,
|
||||||
|
interp=interp,
|
||||||
|
force=force,
|
||||||
|
**kwargs)
|
||||||
|
dump_path = global_models_dir() / 'merged_diffusers'
|
||||||
|
os.makedirs(dump_path,exist_ok=True)
|
||||||
|
dump_path = dump_path / merged_model_name
|
||||||
|
merged_pipe.save_pretrained (
|
||||||
|
dump_path,
|
||||||
|
safe_serialization=1
|
||||||
|
)
|
||||||
|
model_manager.import_diffuser_model(
|
||||||
|
dump_path,
|
||||||
|
model_name = merged_model_name,
|
||||||
|
description = f'Merge of models {", ".join(models)}'
|
||||||
|
)
|
||||||
|
print('REMINDER: When PR 2369 is merged, replace merge_diffusers.py line 56 with vae= argument to impormodel()')
|
||||||
|
if vae := model_manager.config[models[0]].get('vae',None):
|
||||||
|
print(f'>> Using configured VAE assigned to {models[0]}')
|
||||||
|
model_manager.config[merged_model_name]['vae'] = vae
|
||||||
|
|
||||||
|
model_manager.commit(config_file)
|
@ -37,7 +37,11 @@ from ldm.util import instantiate_from_config, ask_user
|
|||||||
DEFAULT_MAX_MODELS=2
|
DEFAULT_MAX_MODELS=2
|
||||||
|
|
||||||
class ModelManager(object):
|
class ModelManager(object):
|
||||||
def __init__(self, config:OmegaConf, device_type:str, precision:str, max_loaded_models=DEFAULT_MAX_MODELS):
|
def __init__(self,
|
||||||
|
config:OmegaConf,
|
||||||
|
device_type:str='cpu',
|
||||||
|
precision:str='float16',
|
||||||
|
max_loaded_models=DEFAULT_MAX_MODELS):
|
||||||
'''
|
'''
|
||||||
Initialize with the path to the models.yaml config file,
|
Initialize with the path to the models.yaml config file,
|
||||||
the torch device type, and precision. The optional
|
the torch device type, and precision. The optional
|
||||||
@ -536,7 +540,7 @@ class ModelManager(object):
|
|||||||
format='diffusers',
|
format='diffusers',
|
||||||
)
|
)
|
||||||
if isinstance(repo_or_path,Path) and repo_or_path.exists():
|
if isinstance(repo_or_path,Path) and repo_or_path.exists():
|
||||||
new_config.update(path=repo_or_path)
|
new_config.update(path=str(repo_or_path))
|
||||||
else:
|
else:
|
||||||
new_config.update(repo_id=repo_or_path)
|
new_config.update(repo_id=repo_or_path)
|
||||||
|
|
||||||
|
0
scripts/load_models.py
Normal file → Executable file
0
scripts/load_models.py
Normal file → Executable file
92
scripts/merge_models.py
Executable file
92
scripts/merge_models.py
Executable file
@ -0,0 +1,92 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
|
from ldm.invoke.globals import (Globals, global_cache_dir, global_config_file,
|
||||||
|
global_set_root)
|
||||||
|
from ldm.invoke.model_manager import ModelManager
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description="InvokeAI textual inversion training")
|
||||||
|
parser.add_argument(
|
||||||
|
"--root_dir",
|
||||||
|
"--root-dir",
|
||||||
|
type=Path,
|
||||||
|
default=Globals.root,
|
||||||
|
help="Path to the invokeai runtime directory",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--models",
|
||||||
|
required=True,
|
||||||
|
type=str,
|
||||||
|
nargs="+",
|
||||||
|
help="Two to three model names to be merged",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--merged_model_name",
|
||||||
|
"--destination",
|
||||||
|
dest="merged_model_name",
|
||||||
|
type=str,
|
||||||
|
help="Name of the output model. If not specified, will be the concatenation of the input model names.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--alpha",
|
||||||
|
type=float,
|
||||||
|
default=0.5,
|
||||||
|
help="The interpolation parameter, ranging from 0 to 1. It affects the ratio in which the checkpoints are merged. Higher values give more weight to the 2d and 3d models",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--interpolation",
|
||||||
|
dest="interp",
|
||||||
|
type=str,
|
||||||
|
choices=["weighted_sum", "sigmoid", "inv_sigmoid", "add_difference"],
|
||||||
|
default="weighted_sum",
|
||||||
|
help='Interpolation method to use. If three models are present, only "add_difference" will work.',
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--force",
|
||||||
|
action="store_true",
|
||||||
|
help="Try to merge models even if they are incompatible with each other",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--clobber",
|
||||||
|
"--overwrite",
|
||||||
|
dest='clobber',
|
||||||
|
action="store_true",
|
||||||
|
help="Overwrite the merged model if --merged_model_name already exists",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
global_set_root(args.root_dir)
|
||||||
|
|
||||||
|
assert args.alpha >= 0 and args.alpha <= 1.0, "alpha must be between 0 and 1"
|
||||||
|
assert len(args.models) >= 1 and len(args.models) <= 3, "provide 2 or 3 models to merge"
|
||||||
|
|
||||||
|
if not args.merged_model_name:
|
||||||
|
args.merged_model_name = "+".join(args.models)
|
||||||
|
print(
|
||||||
|
f'>> No --merged_model_name provided. Defaulting to "{args.merged_model_name}"'
|
||||||
|
)
|
||||||
|
|
||||||
|
model_manager = ModelManager(OmegaConf.load(global_config_file()))
|
||||||
|
assert (args.clobber or args.merged_model_name not in model_manager.model_names()), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
|
||||||
|
|
||||||
|
# It seems that the merge pipeline is not honoring cache_dir, so we set the
|
||||||
|
# HF_HOME environment variable here *before* we load diffusers.
|
||||||
|
cache_dir = str(global_cache_dir("diffusers"))
|
||||||
|
os.environ["HF_HOME"] = cache_dir
|
||||||
|
from ldm.invoke.merge_diffusers import merge_diffusion_models
|
||||||
|
|
||||||
|
try:
|
||||||
|
merge_diffusion_models(**vars(args))
|
||||||
|
print(f'>> Models merged into new model: "{args.merged_model_name}".')
|
||||||
|
except Exception as e:
|
||||||
|
print(f"** An error occurred while merging the pipelines: {str(e)}")
|
||||||
|
print("** DETAILS:")
|
||||||
|
print(traceback.format_exc())
|
||||||
|
sys.exit(-1)
|
87
scripts/merge_fe.py → scripts/merge_models_fe.py
Normal file → Executable file
87
scripts/merge_fe.py → scripts/merge_models_fe.py
Normal file → Executable file
@ -3,11 +3,10 @@
|
|||||||
import npyscreen
|
import npyscreen
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import re
|
|
||||||
import shutil
|
|
||||||
import traceback
|
import traceback
|
||||||
import argparse
|
import argparse
|
||||||
from ldm.invoke.globals import Globals, global_set_root
|
from ldm.invoke.globals import Globals, global_set_root, global_cache_dir, global_config_file
|
||||||
|
from ldm.invoke.model_manager import ModelManager
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List
|
from typing import List
|
||||||
@ -30,6 +29,14 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
|||||||
'inv_sigmoid',
|
'inv_sigmoid',
|
||||||
'add_difference']
|
'add_difference']
|
||||||
|
|
||||||
|
def __init__(self, parentApp, name):
|
||||||
|
self.parentApp = parentApp
|
||||||
|
super().__init__(parentApp, name)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_manager(self):
|
||||||
|
return self.parentApp.model_manager
|
||||||
|
|
||||||
def afterEditing(self):
|
def afterEditing(self):
|
||||||
self.parentApp.setNextForm(None)
|
self.parentApp.setNextForm(None)
|
||||||
|
|
||||||
@ -83,6 +90,11 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
|||||||
lowest=0,
|
lowest=0,
|
||||||
value=0.5,
|
value=0.5,
|
||||||
)
|
)
|
||||||
|
self.force = self.add_widget_intelligent(
|
||||||
|
npyscreen.Checkbox,
|
||||||
|
name='Force merge of incompatible models',
|
||||||
|
value=False,
|
||||||
|
)
|
||||||
self.merged_model_name = self.add_widget_intelligent(
|
self.merged_model_name = self.add_widget_intelligent(
|
||||||
npyscreen.TitleText,
|
npyscreen.TitleText,
|
||||||
name='Name for merged model',
|
name='Name for merged model',
|
||||||
@ -105,20 +117,51 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
|||||||
self.merge_method.value=0
|
self.merge_method.value=0
|
||||||
|
|
||||||
def on_ok(self):
|
def on_ok(self):
|
||||||
if self.validate_field_values():
|
if self.validate_field_values() and self.check_for_overwrite():
|
||||||
self.parentApp.setNextForm(None)
|
self.parentApp.setNextForm(None)
|
||||||
self.editing = False
|
self.editing = False
|
||||||
|
self.parentApp.merge_arguments = self.marshall_arguments()
|
||||||
|
npyscreen.notify('Starting the merge...')
|
||||||
|
import ldm.invoke.merge_diffusers # this keeps the message up while diffusers loads
|
||||||
else:
|
else:
|
||||||
self.editing = True
|
self.editing = True
|
||||||
|
|
||||||
def ok_cancel(self):
|
def on_cancel(self):
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
|
def marshall_arguments(self)->dict:
|
||||||
|
model_names = self.model_names
|
||||||
|
models = [
|
||||||
|
model_names[self.model1.value[0]],
|
||||||
|
model_names[self.model2.value[0]],
|
||||||
|
]
|
||||||
|
if self.model3.value[0] > 0:
|
||||||
|
models.append(model_names[self.model3.value[0]-1])
|
||||||
|
|
||||||
|
args = dict(
|
||||||
|
models=models,
|
||||||
|
alpha = self.alpha.value,
|
||||||
|
interp = self.interpolations[self.merge_method.value[0]],
|
||||||
|
force = self.force.value,
|
||||||
|
merged_model_name = self.merged_model_name.value,
|
||||||
|
)
|
||||||
|
return args
|
||||||
|
|
||||||
|
def check_for_overwrite(self)->bool:
|
||||||
|
model_out = self.merged_model_name.value
|
||||||
|
if model_out not in self.model_names:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return npyscreen.notify_yes_no(f'The chosen merged model destination, {model_out}, is already in use. Overwrite?')
|
||||||
|
|
||||||
def validate_field_values(self)->bool:
|
def validate_field_values(self)->bool:
|
||||||
bad_fields = []
|
bad_fields = []
|
||||||
selected_models = set((self.model1.value[0],self.model2.value[0],self.model3.value[0]))
|
model_names = self.model_names
|
||||||
if len(selected_models) < 3:
|
selected_models = set((model_names[self.model1.value[0]],model_names[self.model2.value[0]]))
|
||||||
bad_fields.append('Please select two or three DIFFERENT models to compare')
|
if self.model3.value[0] > 0:
|
||||||
|
selected_models.add(model_names[self.model3.value[0]-1])
|
||||||
|
if len(selected_models) < 2:
|
||||||
|
bad_fields.append(f'Please select two or three DIFFERENT models to compare. You selected {selected_models}')
|
||||||
if len(bad_fields) > 0:
|
if len(bad_fields) > 0:
|
||||||
message = 'The following problems were detected and must be corrected:'
|
message = 'The following problems were detected and must be corrected:'
|
||||||
for problem in bad_fields:
|
for problem in bad_fields:
|
||||||
@ -129,13 +172,15 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
def get_model_names(self)->List[str]:
|
def get_model_names(self)->List[str]:
|
||||||
conf = OmegaConf.load(os.path.join(Globals.root,'configs/models.yaml'))
|
model_names = [name for name in self.model_manager.model_names() if self.model_manager.model_info(name).get('format') == 'diffusers']
|
||||||
model_names = [name for name in conf.keys() if conf[name].get('format',None)=='diffusers']
|
print(model_names)
|
||||||
return sorted(model_names)
|
return sorted(model_names)
|
||||||
|
|
||||||
class MyApplication(npyscreen.NPSAppManaged):
|
class Mergeapp(npyscreen.NPSAppManaged):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
conf = OmegaConf.load(global_config_file())
|
||||||
|
self.model_manager = ModelManager(conf,'cpu','float16') # precision doesn't really matter here
|
||||||
|
|
||||||
def onStart(self):
|
def onStart(self):
|
||||||
npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
|
npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
|
||||||
@ -152,5 +197,21 @@ if __name__ == '__main__':
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
global_set_root(args.root_dir)
|
global_set_root(args.root_dir)
|
||||||
|
|
||||||
myapplication = MyApplication()
|
cache_dir = str(global_cache_dir('diffusers')) # because not clear the merge pipeline is honoring cache_dir
|
||||||
myapplication.run()
|
os.environ['HF_HOME'] = cache_dir
|
||||||
|
|
||||||
|
mergeapp = Mergeapp()
|
||||||
|
mergeapp.run()
|
||||||
|
|
||||||
|
args = mergeapp.merge_arguments
|
||||||
|
args.update(cache_dir = cache_dir)
|
||||||
|
from ldm.invoke.merge_diffusers import merge_diffusion_models
|
||||||
|
|
||||||
|
try:
|
||||||
|
merge_diffusion_models(**args)
|
||||||
|
print(f'>> Models merged into new model: "{args["merged_model_name"]}".')
|
||||||
|
except Exception as e:
|
||||||
|
print(f'** An error occurred while merging the pipelines: {str(e)}')
|
||||||
|
print('** DETAILS:')
|
||||||
|
print(traceback.format_exc())
|
||||||
|
sys.exit(-1)
|
0
scripts/merge_embeddings.py → scripts/orig_scripts/merge_embeddings.py
Normal file → Executable file
0
scripts/merge_embeddings.py → scripts/orig_scripts/merge_embeddings.py
Normal file → Executable file
5
setup.py
5
setup.py
@ -92,8 +92,9 @@ setup(
|
|||||||
'Topic :: Scientific/Engineering :: Image Processing',
|
'Topic :: Scientific/Engineering :: Image Processing',
|
||||||
],
|
],
|
||||||
scripts = ['scripts/invoke.py','scripts/configure_invokeai.py', 'scripts/sd-metadata.py',
|
scripts = ['scripts/invoke.py','scripts/configure_invokeai.py', 'scripts/sd-metadata.py',
|
||||||
'scripts/preload_models.py', 'scripts/images2prompt.py','scripts/merge_embeddings.py',
|
'scripts/preload_models.py', 'scripts/images2prompt.py',
|
||||||
'scripts/textual_inversion_fe.py','scripts/textual_inversion.py'
|
'scripts/textual_inversion_fe.py','scripts/textual_inversion.py',
|
||||||
|
'scripts/merge_models_fe.py', 'scripts/merge_models.py',
|
||||||
],
|
],
|
||||||
data_files=FRONTEND_FILES,
|
data_files=FRONTEND_FILES,
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user