[WebUI] Model Merging (#2699)

This PR brings Model Merging to the WebUI.

Inside the Model Manager, you can now find a new button called Merge
Models. The rest of it is self-explanatory.


![firefox_BYCM4YNHEa](https://user-images.githubusercontent.com/54517381/219795631-dbb5c5c4-fc3a-4cdd-9549-18c2e5302835.png)
This commit is contained in:
blessedcoolant 2023-02-18 14:34:35 +13:00 committed by GitHub
commit 767012aec0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 1240 additions and 769 deletions

View File

@ -31,6 +31,8 @@ from ldm.invoke.generator.inpaint import infill_methods
from ldm.invoke.globals import Globals, global_converted_ckpts_dir from ldm.invoke.globals import Globals, global_converted_ckpts_dir
from ldm.invoke.pngwriter import PngWriter, retrieve_metadata from ldm.invoke.pngwriter import PngWriter, retrieve_metadata
from ldm.invoke.prompt_parser import split_weighted_subprompts, Blend from ldm.invoke.prompt_parser import split_weighted_subprompts, Blend
from ldm.invoke.globals import global_models_dir
from ldm.invoke.merge_diffusers import merge_diffusion_models
# Loading Arguments # Loading Arguments
opt = Args() opt = Args()
@ -205,11 +207,7 @@ class InvokeAIWebServer:
return make_response(response, 200) return make_response(response, 200)
except Exception as e: except Exception as e:
self.socketio.emit("error", {"message": (str(e))}) self.handle_exceptions(e)
print("\n")
traceback.print_exc()
print("\n")
return make_response("Error uploading file", 500) return make_response("Error uploading file", 500)
self.load_socketio_listeners(self.socketio) self.load_socketio_listeners(self.socketio)
@ -317,10 +315,7 @@ class InvokeAIWebServer:
'found_models': found_models}, 'found_models': found_models},
) )
except Exception as e: except Exception as e:
self.socketio.emit("error", {"message": (str(e))}) self.handle_exceptions(e)
print("\n")
traceback.print_exc()
print("\n") print("\n")
@socketio.on("addNewModel") @socketio.on("addNewModel")
@ -350,11 +345,7 @@ class InvokeAIWebServer:
) )
print(f">> New Model Added: {model_name}") print(f">> New Model Added: {model_name}")
except Exception as e: except Exception as e:
self.socketio.emit("error", {"message": (str(e))}) self.handle_exceptions(e)
print("\n")
traceback.print_exc()
print("\n")
@socketio.on("deleteModel") @socketio.on("deleteModel")
def handle_delete_model(model_name: str): def handle_delete_model(model_name: str):
@ -370,11 +361,7 @@ class InvokeAIWebServer:
) )
print(f">> Model Deleted: {model_name}") print(f">> Model Deleted: {model_name}")
except Exception as e: except Exception as e:
self.socketio.emit("error", {"message": (str(e))}) self.handle_exceptions(e)
print("\n")
traceback.print_exc()
print("\n")
@socketio.on("requestModelChange") @socketio.on("requestModelChange")
def handle_set_model(model_name: str): def handle_set_model(model_name: str):
@ -393,11 +380,7 @@ class InvokeAIWebServer:
{"model_name": model_name, "model_list": model_list}, {"model_name": model_name, "model_list": model_list},
) )
except Exception as e: except Exception as e:
self.socketio.emit("error", {"message": (str(e))}) self.handle_exceptions(e)
print("\n")
traceback.print_exc()
print("\n")
@socketio.on('convertToDiffusers') @socketio.on('convertToDiffusers')
def convert_to_diffusers(model_to_convert: dict): def convert_to_diffusers(model_to_convert: dict):
@ -428,10 +411,12 @@ class InvokeAIWebServer:
) )
if model_to_convert['save_location'] == 'root': if model_to_convert['save_location'] == 'root':
diffusers_path = Path(global_converted_ckpts_dir(), f'{model_name}_diffusers') diffusers_path = Path(
global_converted_ckpts_dir(), f'{model_name}_diffusers')
if model_to_convert['save_location'] == 'custom' and model_to_convert['custom_location'] is not None: if model_to_convert['save_location'] == 'custom' and model_to_convert['custom_location'] is not None:
diffusers_path = Path(model_to_convert['custom_location'], f'{model_name}_diffusers') diffusers_path = Path(
model_to_convert['custom_location'], f'{model_name}_diffusers')
if diffusers_path.exists(): if diffusers_path.exists():
shutil.rmtree(diffusers_path) shutil.rmtree(diffusers_path)
@ -454,11 +439,51 @@ class InvokeAIWebServer:
) )
print(f">> Model Converted: {model_name}") print(f">> Model Converted: {model_name}")
except Exception as e: except Exception as e:
self.socketio.emit("error", {"message": (str(e))}) self.handle_exceptions(e)
print("\n")
traceback.print_exc() @socketio.on('mergeDiffusersModels')
print("\n") def merge_diffusers_models(model_merge_info: dict):
try:
models_to_merge = model_merge_info['models_to_merge']
model_ids_or_paths = [
self.generate.model_manager.model_name_or_path(x) for x in models_to_merge]
merged_pipe = merge_diffusion_models(
model_ids_or_paths, model_merge_info['alpha'], model_merge_info['interp'], model_merge_info['force'])
dump_path = global_models_dir() / 'merged_models'
if model_merge_info['model_merge_save_path'] is not None:
dump_path = Path(model_merge_info['model_merge_save_path'])
os.makedirs(dump_path, exist_ok=True)
dump_path = dump_path / model_merge_info['merged_model_name']
merged_pipe.save_pretrained(dump_path, safe_serialization=1)
merged_model_config = dict(
model_name=model_merge_info['merged_model_name'],
description=f'Merge of models {", ".join(models_to_merge)}',
commit_to_conf=opt.conf
)
if vae := self.generate.model_manager.config[models_to_merge[0]].get("vae", None):
print(
f">> Using configured VAE assigned to {models_to_merge[0]}")
merged_model_config.update(vae=vae)
self.generate.model_manager.import_diffuser_model(
dump_path, **merged_model_config)
new_model_list = self.generate.model_manager.list_models()
socketio.emit(
"modelsMerged",
{"merged_models": models_to_merge,
"merged_model_name": model_merge_info['merged_model_name'],
"model_list": new_model_list, 'update': True},
)
print(f">> Models Merged: {models_to_merge}")
print(
f">> New Model Added: {model_merge_info['merged_model_name']}")
except Exception as e:
self.handle_exceptions(e)
@socketio.on("requestEmptyTempFolder") @socketio.on("requestEmptyTempFolder")
def empty_temp_folder(): def empty_temp_folder():
@ -479,11 +504,7 @@ class InvokeAIWebServer:
socketio.emit("tempFolderEmptied") socketio.emit("tempFolderEmptied")
except Exception as e: except Exception as e:
self.socketio.emit("error", {"message": (str(e))}) self.handle_exceptions(e)
print("\n")
traceback.print_exc()
print("\n")
@socketio.on("requestSaveStagingAreaImageToGallery") @socketio.on("requestSaveStagingAreaImageToGallery")
def save_temp_image_to_gallery(url): def save_temp_image_to_gallery(url):
@ -525,11 +546,7 @@ class InvokeAIWebServer:
) )
except Exception as e: except Exception as e:
self.socketio.emit("error", {"message": (str(e))}) self.handle_exceptions(e)
print("\n")
traceback.print_exc()
print("\n")
@socketio.on("requestLatestImages") @socketio.on("requestLatestImages")
def handle_request_latest_images(category, latest_mtime): def handle_request_latest_images(category, latest_mtime):
@ -595,11 +612,7 @@ class InvokeAIWebServer:
{"images": image_array, "category": category}, {"images": image_array, "category": category},
) )
except Exception as e: except Exception as e:
self.socketio.emit("error", {"message": (str(e))}) self.handle_exceptions(e)
print("\n")
traceback.print_exc()
print("\n")
@socketio.on("requestImages") @socketio.on("requestImages")
def handle_request_images(category, earliest_mtime=None): def handle_request_images(category, earliest_mtime=None):
@ -674,11 +687,7 @@ class InvokeAIWebServer:
}, },
) )
except Exception as e: except Exception as e:
self.socketio.emit("error", {"message": (str(e))}) self.handle_exceptions(e)
print("\n")
traceback.print_exc()
print("\n")
@socketio.on("generateImage") @socketio.on("generateImage")
def handle_generate_image_event( def handle_generate_image_event(
@ -711,11 +720,7 @@ class InvokeAIWebServer:
facetool_parameters, facetool_parameters,
) )
except Exception as e: except Exception as e:
self.socketio.emit("error", {"message": (str(e))}) self.handle_exceptions(e)
print("\n")
traceback.print_exc()
print("\n")
@socketio.on("runPostprocessing") @socketio.on("runPostprocessing")
def handle_run_postprocessing(original_image, postprocessing_parameters): def handle_run_postprocessing(original_image, postprocessing_parameters):
@ -829,11 +834,7 @@ class InvokeAIWebServer:
}, },
) )
except Exception as e: except Exception as e:
self.socketio.emit("error", {"message": (str(e))}) self.handle_exceptions(e)
print("\n")
traceback.print_exc()
print("\n")
@socketio.on("cancel") @socketio.on("cancel")
def handle_cancel(): def handle_cancel():
@ -858,11 +859,7 @@ class InvokeAIWebServer:
{"url": url, "uuid": uuid, "category": category}, {"url": url, "uuid": uuid, "category": category},
) )
except Exception as e: except Exception as e:
self.socketio.emit("error", {"message": (str(e))}) self.handle_exceptions(e)
print("\n")
traceback.print_exc()
print("\n")
# App Functions # App Functions
def get_system_config(self): def get_system_config(self):
@ -1312,11 +1309,7 @@ class InvokeAIWebServer:
# Clear the CUDA cache on an exception # Clear the CUDA cache on an exception
self.empty_cuda_cache() self.empty_cuda_cache()
print(e) print(e)
self.socketio.emit("error", {"message": (str(e))}) self.handle_exceptions(e)
print("\n")
traceback.print_exc()
print("\n")
def empty_cuda_cache(self): def empty_cuda_cache(self):
if self.generate.device.type == "cuda": if self.generate.device.type == "cuda":
@ -1423,11 +1416,7 @@ class InvokeAIWebServer:
return metadata return metadata
except Exception as e: except Exception as e:
self.socketio.emit("error", {"message": (str(e))}) self.handle_exceptions(e)
print("\n")
traceback.print_exc()
print("\n")
def parameters_to_post_processed_image_metadata( def parameters_to_post_processed_image_metadata(
self, parameters, original_image_path self, parameters, original_image_path
@ -1480,11 +1469,7 @@ class InvokeAIWebServer:
return current_metadata return current_metadata
except Exception as e: except Exception as e:
self.socketio.emit("error", {"message": (str(e))}) self.handle_exceptions(e)
print("\n")
traceback.print_exc()
print("\n")
def save_result_image( def save_result_image(
self, self,
@ -1528,11 +1513,7 @@ class InvokeAIWebServer:
return os.path.abspath(path) return os.path.abspath(path)
except Exception as e: except Exception as e:
self.socketio.emit("error", {"message": (str(e))}) self.handle_exceptions(e)
print("\n")
traceback.print_exc()
print("\n")
def make_unique_init_image_filename(self, name): def make_unique_init_image_filename(self, name):
try: try:
@ -1541,11 +1522,7 @@ class InvokeAIWebServer:
name = f"{split[0]}.{uuid}{split[1]}" name = f"{split[0]}.{uuid}{split[1]}"
return name return name
except Exception as e: except Exception as e:
self.socketio.emit("error", {"message": (str(e))}) self.handle_exceptions(e)
print("\n")
traceback.print_exc()
print("\n")
def calculate_real_steps(self, steps, strength, has_init_image): def calculate_real_steps(self, steps, strength, has_init_image):
import math import math
@ -1560,11 +1537,7 @@ class InvokeAIWebServer:
file.writelines(message) file.writelines(message)
except Exception as e: except Exception as e:
self.socketio.emit("error", {"message": (str(e))}) self.handle_exceptions(e)
print("\n")
traceback.print_exc()
print("\n")
def get_image_path_from_url(self, url): def get_image_path_from_url(self, url):
"""Given a url to an image used by the client, returns the absolute file path to that image""" """Given a url to an image used by the client, returns the absolute file path to that image"""
@ -1595,11 +1568,7 @@ class InvokeAIWebServer:
os.path.join(self.result_path, os.path.basename(url)) os.path.join(self.result_path, os.path.basename(url))
) )
except Exception as e: except Exception as e:
self.socketio.emit("error", {"message": (str(e))}) self.handle_exceptions(e)
print("\n")
traceback.print_exc()
print("\n")
def get_url_from_image_path(self, path): def get_url_from_image_path(self, path):
"""Given an absolute file path to an image, returns the URL that the client can use to load the image""" """Given an absolute file path to an image, returns the URL that the client can use to load the image"""
@ -1617,11 +1586,7 @@ class InvokeAIWebServer:
else: else:
return os.path.join(self.result_url, os.path.basename(path)) return os.path.join(self.result_url, os.path.basename(path))
except Exception as e: except Exception as e:
self.socketio.emit("error", {"message": (str(e))}) self.handle_exceptions(e)
print("\n")
traceback.print_exc()
print("\n")
def save_file_unique_uuid_name(self, bytes, name, path): def save_file_unique_uuid_name(self, bytes, name, path):
try: try:
@ -1640,11 +1605,13 @@ class InvokeAIWebServer:
return file_path return file_path
except Exception as e: except Exception as e:
self.socketio.emit("error", {"message": (str(e))}) self.handle_exceptions(e)
print("\n")
traceback.print_exc() def handle_exceptions(self, exception, emit_key: str = 'error'):
print("\n") self.socketio.emit(emit_key, {"message": (str(exception))})
print("\n")
traceback.print_exc()
print("\n")
class Progress: class Progress:

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -5,7 +5,7 @@
<meta name="viewport" content="width=device-width, initial-scale=1.0" /> <meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>InvokeAI - A Stable Diffusion Toolkit</title> <title>InvokeAI - A Stable Diffusion Toolkit</title>
<link rel="shortcut icon" type="icon" href="./assets/favicon-0d253ced.ico" /> <link rel="shortcut icon" type="icon" href="./assets/favicon-0d253ced.ico" />
<script type="module" crossorigin src="./assets/index-9237ac63.js"></script> <script type="module" crossorigin src="./assets/index-7062a172.js"></script>
<link rel="stylesheet" href="./assets/index-14cb2922.css"> <link rel="stylesheet" href="./assets/index-14cb2922.css">
</head> </head>

View File

@ -9,6 +9,19 @@
"darkTheme": "Dark", "darkTheme": "Dark",
"lightTheme": "Light", "lightTheme": "Light",
"greenTheme": "Green", "greenTheme": "Green",
"langArabic": "العربية",
"langEnglish": "English",
"langDutch": "Nederlands",
"langFrench": "Français",
"langGerman": "Deutsch",
"langItalian": "Italiano",
"langJapanese": "日本語",
"langPolish": "Polski",
"langBrPortuguese": "Português do Brasil",
"langRussian": "Русский",
"langSimplifiedChinese": "简体中文",
"langUkranian": "Украї́нська",
"langSpanish": "Español",
"text2img": "Text To Image", "text2img": "Text To Image",
"img2img": "Image To Image", "img2img": "Image To Image",
"unifiedCanvas": "Unified Canvas", "unifiedCanvas": "Unified Canvas",
@ -45,5 +58,9 @@
"statusUpscaling": "Upscaling", "statusUpscaling": "Upscaling",
"statusUpscalingESRGAN": "Upscaling (ESRGAN)", "statusUpscalingESRGAN": "Upscaling (ESRGAN)",
"statusLoadingModel": "Loading Model", "statusLoadingModel": "Loading Model",
"statusModelChanged": "Model Changed" "statusModelChanged": "Model Changed",
"statusConvertingModel": "Converting Model",
"statusModelConverted": "Model Converted",
"statusMergingModels": "Merging Models",
"statusMergedModels": "Models Merged"
} }

View File

@ -60,5 +60,7 @@
"statusLoadingModel": "Loading Model", "statusLoadingModel": "Loading Model",
"statusModelChanged": "Model Changed", "statusModelChanged": "Model Changed",
"statusConvertingModel": "Converting Model", "statusConvertingModel": "Converting Model",
"statusModelConverted": "Model Converted" "statusModelConverted": "Model Converted",
"statusMergingModels": "Merging Models",
"statusMergedModels": "Models Merged"
} }

View File

@ -22,7 +22,7 @@
"config": "Config", "config": "Config",
"configValidationMsg": "Path to the config file of your model.", "configValidationMsg": "Path to the config file of your model.",
"modelLocation": "Model Location", "modelLocation": "Model Location",
"modelLocationValidationMsg": "Path to where your model is located.", "modelLocationValidationMsg": "Path to where your model is located locally.",
"repo_id": "Repo ID", "repo_id": "Repo ID",
"repoIDValidationMsg": "Online repository of your model", "repoIDValidationMsg": "Online repository of your model",
"vaeLocation": "VAE Location", "vaeLocation": "VAE Location",
@ -72,14 +72,33 @@
"convertToDiffusersHelpText4": "This is a one time process only. It might take around 30s-60s depending on the specifications of your computer.", "convertToDiffusersHelpText4": "This is a one time process only. It might take around 30s-60s depending on the specifications of your computer.",
"convertToDiffusersHelpText5": "Please make sure you have enough disk space. Models generally vary between 4GB-7GB in size.", "convertToDiffusersHelpText5": "Please make sure you have enough disk space. Models generally vary between 4GB-7GB in size.",
"convertToDiffusersHelpText6": "Do you wish to convert this model?", "convertToDiffusersHelpText6": "Do you wish to convert this model?",
"convertToDiffusersSaveLocation": "Save Location",
"v1": "v1", "v1": "v1",
"v2": "v2", "v2": "v2",
"inpainting": "v1 Inpainting", "inpainting": "v1 Inpainting",
"customConfig": "Custom Config", "customConfig": "Custom Config",
"pathToCustomConfig": "Path To Custom Config", "pathToCustomConfig": "Path To Custom Config",
"statusConverting": "Converting", "statusConverting": "Converting",
"sameFolder": "Same Folder", "modelConverted": "Model Converted",
"invokeRoot": "Invoke Models", "sameFolder": "Same folder",
"invokeRoot": "InvokeAI folder",
"custom": "Custom", "custom": "Custom",
"customSaveLocation": "Custom Save Location" "customSaveLocation": "Custom Save Location",
"merge": "Merge",
"modelsMerged": "Models Merged",
"mergeModels": "Merge Models",
"modelOne": "Model 1",
"modelTwo": "Model 2",
"modelThree": "Model 3",
"mergedModelName": "Merged Model Name",
"alpha": "Alpha",
"interpolationType": "Interpolation Type",
"mergedModelSaveLocation": "Save Location",
"mergedModelCustomSaveLocation": "Custom Path",
"invokeAIFolder": "Invoke AI Folder",
"ignoreMismatch": "Ignore Mismatches Between Selected Models",
"modelMergeHeaderHelp1": "You can merge upto three different models to create a blend that suits your needs.",
"modelMergeHeaderHelp2": "Only Diffusers are available for merging. If you want to merge a checkpoint model, please convert it to Diffusers first.",
"modelMergeAlphaHelp": "Alpha controls blend strength for the models. Lower alpha values lead to lower influence of the second model.",
"modelMergeInterpAddDifferenceHelp": "In this mode, Model 3 is first subtracted from Model 2. The resulting version is blended with Model 1 with the alpha rate set above."
} }

View File

@ -76,12 +76,29 @@
"v1": "v1", "v1": "v1",
"v2": "v2", "v2": "v2",
"inpainting": "v1 Inpainting", "inpainting": "v1 Inpainting",
"customConfig": "Custom Config", "customConfig": "Custom Config",
"pathToCustomConfig": "Path To Custom Config", "pathToCustomConfig": "Path To Custom Config",
"statusConverting": "Converting", "statusConverting": "Converting",
"modelConverted": "Model Converted", "modelConverted": "Model Converted",
"sameFolder": "Same folder", "sameFolder": "Same folder",
"invokeRoot": "InvokeAI folder", "invokeRoot": "InvokeAI folder",
"custom": "Custom", "custom": "Custom",
"customSaveLocation": "Custom Save Location" "customSaveLocation": "Custom Save Location",
"merge": "Merge",
"modelsMerged": "Models Merged",
"mergeModels": "Merge Models",
"modelOne": "Model 1",
"modelTwo": "Model 2",
"modelThree": "Model 3",
"mergedModelName": "Merged Model Name",
"alpha": "Alpha",
"interpolationType": "Interpolation Type",
"mergedModelSaveLocation": "Save Location",
"mergedModelCustomSaveLocation": "Custom Path",
"invokeAIFolder": "Invoke AI Folder",
"ignoreMismatch": "Ignore Mismatches Between Selected Models",
"modelMergeHeaderHelp1": "You can merge upto three different models to create a blend that suits your needs.",
"modelMergeHeaderHelp2": "Only Diffusers are available for merging. If you want to merge a checkpoint model, please convert it to Diffusers first.",
"modelMergeAlphaHelp": "Alpha controls blend strength for the models. Lower alpha values lead to lower influence of the second model.",
"modelMergeInterpAddDifferenceHelp": "In this mode, Model 3 is first subtracted from Model 2. The resulting version is blended with Model 1 with the alpha rate set above."
} }

View File

@ -22,6 +22,7 @@
"upscaling": "Upscaling", "upscaling": "Upscaling",
"upscale": "Upscale", "upscale": "Upscale",
"upscaleImage": "Upscale Image", "upscaleImage": "Upscale Image",
"denoisingStrength": "Denoising Strength",
"scale": "Scale", "scale": "Scale",
"otherOptions": "Other Options", "otherOptions": "Other Options",
"seamlessTiling": "Seamless Tiling", "seamlessTiling": "Seamless Tiling",
@ -46,9 +47,11 @@
"invoke": "Invoke", "invoke": "Invoke",
"cancel": "Cancel", "cancel": "Cancel",
"promptPlaceholder": "Type prompt here. [negative tokens], (upweight)++, (downweight)--, swap and blend are available (see docs)", "promptPlaceholder": "Type prompt here. [negative tokens], (upweight)++, (downweight)--, swap and blend are available (see docs)",
"negativePrompts": "Negative Prompts",
"sendTo": "Send to", "sendTo": "Send to",
"sendToImg2Img": "Send to Image to Image", "sendToImg2Img": "Send to Image to Image",
"sendToUnifiedCanvas": "Send To Unified Canvas", "sendToUnifiedCanvas": "Send To Unified Canvas",
"copyImage": "Copy Image",
"copyImageToLink": "Copy Image To Link", "copyImageToLink": "Copy Image To Link",
"downloadImage": "Download Image", "downloadImage": "Download Image",
"openInViewer": "Open In Viewer", "openInViewer": "Open In Viewer",

View File

@ -9,6 +9,19 @@
"darkTheme": "Dark", "darkTheme": "Dark",
"lightTheme": "Light", "lightTheme": "Light",
"greenTheme": "Green", "greenTheme": "Green",
"langArabic": "العربية",
"langEnglish": "English",
"langDutch": "Nederlands",
"langFrench": "Français",
"langGerman": "Deutsch",
"langItalian": "Italiano",
"langJapanese": "日本語",
"langPolish": "Polski",
"langBrPortuguese": "Português do Brasil",
"langRussian": "Русский",
"langSimplifiedChinese": "简体中文",
"langUkranian": "Украї́нська",
"langSpanish": "Español",
"text2img": "Text To Image", "text2img": "Text To Image",
"img2img": "Image To Image", "img2img": "Image To Image",
"unifiedCanvas": "Unified Canvas", "unifiedCanvas": "Unified Canvas",
@ -45,5 +58,9 @@
"statusUpscaling": "Upscaling", "statusUpscaling": "Upscaling",
"statusUpscalingESRGAN": "Upscaling (ESRGAN)", "statusUpscalingESRGAN": "Upscaling (ESRGAN)",
"statusLoadingModel": "Loading Model", "statusLoadingModel": "Loading Model",
"statusModelChanged": "Model Changed" "statusModelChanged": "Model Changed",
"statusConvertingModel": "Converting Model",
"statusModelConverted": "Model Converted",
"statusMergingModels": "Merging Models",
"statusMergedModels": "Models Merged"
} }

View File

@ -60,5 +60,7 @@
"statusLoadingModel": "Loading Model", "statusLoadingModel": "Loading Model",
"statusModelChanged": "Model Changed", "statusModelChanged": "Model Changed",
"statusConvertingModel": "Converting Model", "statusConvertingModel": "Converting Model",
"statusModelConverted": "Model Converted" "statusModelConverted": "Model Converted",
"statusMergingModels": "Merging Models",
"statusMergedModels": "Models Merged"
} }

View File

@ -22,7 +22,7 @@
"config": "Config", "config": "Config",
"configValidationMsg": "Path to the config file of your model.", "configValidationMsg": "Path to the config file of your model.",
"modelLocation": "Model Location", "modelLocation": "Model Location",
"modelLocationValidationMsg": "Path to where your model is located.", "modelLocationValidationMsg": "Path to where your model is located locally.",
"repo_id": "Repo ID", "repo_id": "Repo ID",
"repoIDValidationMsg": "Online repository of your model", "repoIDValidationMsg": "Online repository of your model",
"vaeLocation": "VAE Location", "vaeLocation": "VAE Location",
@ -72,14 +72,33 @@
"convertToDiffusersHelpText4": "This is a one time process only. It might take around 30s-60s depending on the specifications of your computer.", "convertToDiffusersHelpText4": "This is a one time process only. It might take around 30s-60s depending on the specifications of your computer.",
"convertToDiffusersHelpText5": "Please make sure you have enough disk space. Models generally vary between 4GB-7GB in size.", "convertToDiffusersHelpText5": "Please make sure you have enough disk space. Models generally vary between 4GB-7GB in size.",
"convertToDiffusersHelpText6": "Do you wish to convert this model?", "convertToDiffusersHelpText6": "Do you wish to convert this model?",
"convertToDiffusersSaveLocation": "Save Location",
"v1": "v1", "v1": "v1",
"v2": "v2", "v2": "v2",
"inpainting": "v1 Inpainting", "inpainting": "v1 Inpainting",
"customConfig": "Custom Config", "customConfig": "Custom Config",
"pathToCustomConfig": "Path To Custom Config", "pathToCustomConfig": "Path To Custom Config",
"statusConverting": "Converting", "statusConverting": "Converting",
"sameFolder": "Same Folder", "modelConverted": "Model Converted",
"invokeRoot": "Invoke Models", "sameFolder": "Same folder",
"invokeRoot": "InvokeAI folder",
"custom": "Custom", "custom": "Custom",
"customSaveLocation": "Custom Save Location" "customSaveLocation": "Custom Save Location",
"merge": "Merge",
"modelsMerged": "Models Merged",
"mergeModels": "Merge Models",
"modelOne": "Model 1",
"modelTwo": "Model 2",
"modelThree": "Model 3",
"mergedModelName": "Merged Model Name",
"alpha": "Alpha",
"interpolationType": "Interpolation Type",
"mergedModelSaveLocation": "Save Location",
"mergedModelCustomSaveLocation": "Custom Path",
"invokeAIFolder": "Invoke AI Folder",
"ignoreMismatch": "Ignore Mismatches Between Selected Models",
"modelMergeHeaderHelp1": "You can merge upto three different models to create a blend that suits your needs.",
"modelMergeHeaderHelp2": "Only Diffusers are available for merging. If you want to merge a checkpoint model, please convert it to Diffusers first.",
"modelMergeAlphaHelp": "Alpha controls blend strength for the models. Lower alpha values lead to lower influence of the second model.",
"modelMergeInterpAddDifferenceHelp": "In this mode, Model 3 is first subtracted from Model 2. The resulting version is blended with Model 1 with the alpha rate set above."
} }

View File

@ -76,12 +76,29 @@
"v1": "v1", "v1": "v1",
"v2": "v2", "v2": "v2",
"inpainting": "v1 Inpainting", "inpainting": "v1 Inpainting",
"customConfig": "Custom Config", "customConfig": "Custom Config",
"pathToCustomConfig": "Path To Custom Config", "pathToCustomConfig": "Path To Custom Config",
"statusConverting": "Converting", "statusConverting": "Converting",
"modelConverted": "Model Converted", "modelConverted": "Model Converted",
"sameFolder": "Same folder", "sameFolder": "Same folder",
"invokeRoot": "InvokeAI folder", "invokeRoot": "InvokeAI folder",
"custom": "Custom", "custom": "Custom",
"customSaveLocation": "Custom Save Location" "customSaveLocation": "Custom Save Location",
"merge": "Merge",
"modelsMerged": "Models Merged",
"mergeModels": "Merge Models",
"modelOne": "Model 1",
"modelTwo": "Model 2",
"modelThree": "Model 3",
"mergedModelName": "Merged Model Name",
"alpha": "Alpha",
"interpolationType": "Interpolation Type",
"mergedModelSaveLocation": "Save Location",
"mergedModelCustomSaveLocation": "Custom Path",
"invokeAIFolder": "Invoke AI Folder",
"ignoreMismatch": "Ignore Mismatches Between Selected Models",
"modelMergeHeaderHelp1": "You can merge upto three different models to create a blend that suits your needs.",
"modelMergeHeaderHelp2": "Only Diffusers are available for merging. If you want to merge a checkpoint model, please convert it to Diffusers first.",
"modelMergeAlphaHelp": "Alpha controls blend strength for the models. Lower alpha values lead to lower influence of the second model.",
"modelMergeInterpAddDifferenceHelp": "In this mode, Model 3 is first subtracted from Model 2. The resulting version is blended with Model 1 with the alpha rate set above."
} }

View File

@ -22,6 +22,7 @@
"upscaling": "Upscaling", "upscaling": "Upscaling",
"upscale": "Upscale", "upscale": "Upscale",
"upscaleImage": "Upscale Image", "upscaleImage": "Upscale Image",
"denoisingStrength": "Denoising Strength",
"scale": "Scale", "scale": "Scale",
"otherOptions": "Other Options", "otherOptions": "Other Options",
"seamlessTiling": "Seamless Tiling", "seamlessTiling": "Seamless Tiling",
@ -46,9 +47,11 @@
"invoke": "Invoke", "invoke": "Invoke",
"cancel": "Cancel", "cancel": "Cancel",
"promptPlaceholder": "Type prompt here. [negative tokens], (upweight)++, (downweight)--, swap and blend are available (see docs)", "promptPlaceholder": "Type prompt here. [negative tokens], (upweight)++, (downweight)--, swap and blend are available (see docs)",
"negativePrompts": "Negative Prompts",
"sendTo": "Send to", "sendTo": "Send to",
"sendToImg2Img": "Send to Image to Image", "sendToImg2Img": "Send to Image to Image",
"sendToUnifiedCanvas": "Send To Unified Canvas", "sendToUnifiedCanvas": "Send To Unified Canvas",
"copyImage": "Copy Image",
"copyImageToLink": "Copy Image To Link", "copyImageToLink": "Copy Image To Link",
"downloadImage": "Download Image", "downloadImage": "Download Image",
"openInViewer": "Open In Viewer", "openInViewer": "Open In Viewer",

View File

@ -225,6 +225,15 @@ export declare type InvokeModelConversionProps = {
custom_location: string | null; custom_location: string | null;
}; };
export declare type InvokeModelMergingProps = {
models_to_merge: string[];
alpha: number;
interp: 'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference';
force: boolean;
merged_model_name: string;
model_merge_save_path: string | null;
};
/** /**
* These types type data received from the server via socketio. * These types type data received from the server via socketio.
*/ */
@ -239,6 +248,12 @@ export declare type ModelConvertedResponse = {
model_list: ModelList; model_list: ModelList;
}; };
export declare type ModelsMergedResponse = {
merged_models: string[];
merged_model_name: string;
model_list: ModelList;
};
export declare type ModelAddedResponse = { export declare type ModelAddedResponse = {
new_model_name: string; new_model_name: string;
model_list: ModelList; model_list: ModelList;

View File

@ -43,6 +43,11 @@ export const convertToDiffusers =
'socketio/convertToDiffusers' 'socketio/convertToDiffusers'
); );
export const mergeDiffusersModels =
createAction<InvokeAI.InvokeModelMergingProps>(
'socketio/mergeDiffusersModels'
);
export const requestModelChange = createAction<string>( export const requestModelChange = createAction<string>(
'socketio/requestModelChange' 'socketio/requestModelChange'
); );

View File

@ -16,6 +16,7 @@ import {
generationRequested, generationRequested,
modelChangeRequested, modelChangeRequested,
modelConvertRequested, modelConvertRequested,
modelMergingRequested,
setIsProcessing, setIsProcessing,
} from 'features/system/store/systemSlice'; } from 'features/system/store/systemSlice';
import { InvokeTabName } from 'features/ui/store/tabMap'; import { InvokeTabName } from 'features/ui/store/tabMap';
@ -185,6 +186,12 @@ const makeSocketIOEmitters = (
dispatch(modelConvertRequested()); dispatch(modelConvertRequested());
socketio.emit('convertToDiffusers', modelToConvert); socketio.emit('convertToDiffusers', modelToConvert);
}, },
emitMergeDiffusersModels: (
modelMergeInfo: InvokeAI.InvokeModelMergingProps
) => {
dispatch(modelMergingRequested());
socketio.emit('mergeDiffusersModels', modelMergeInfo);
},
emitRequestModelChange: (modelName: string) => { emitRequestModelChange: (modelName: string) => {
dispatch(modelChangeRequested()); dispatch(modelChangeRequested());
socketio.emit('requestModelChange', modelName); socketio.emit('requestModelChange', modelName);

View File

@ -432,6 +432,28 @@ const makeSocketIOListeners = (
}) })
); );
}, },
// Handles the server's `modelsMerged` event: refresh the model list,
// clear the busy state, log the merge, and show a success toast.
onModelsMerged: (data: InvokeAI.ModelsMergedResponse) => {
const { merged_models, merged_model_name, model_list } = data;
dispatch(setModelList(model_list));
dispatch(setCurrentStatus(i18n.t('common:statusMergedModels')));
dispatch(setIsProcessing(false));
dispatch(setIsCancelable(true));
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Models merged: ${merged_models}`,
level: 'info',
})
);
dispatch(
addToast({
title: `${i18n.t('modelmanager:modelsMerged')}: ${merged_model_name}`,
status: 'success',
duration: 2500,
isClosable: true,
})
);
},
onModelChanged: (data: InvokeAI.ModelChangeResponse) => { onModelChanged: (data: InvokeAI.ModelChangeResponse) => {
const { model_name, model_list } = data; const { model_name, model_list } = data;
dispatch(setModelList(model_list)); dispatch(setModelList(model_list));

View File

@ -49,6 +49,7 @@ export const socketioMiddleware = () => {
onNewModelAdded, onNewModelAdded,
onModelDeleted, onModelDeleted,
onModelConverted, onModelConverted,
onModelsMerged,
onModelChangeFailed, onModelChangeFailed,
onTempFolderEmptied, onTempFolderEmptied,
} = makeSocketIOListeners(store); } = makeSocketIOListeners(store);
@ -66,6 +67,7 @@ export const socketioMiddleware = () => {
emitAddNewModel, emitAddNewModel,
emitDeleteModel, emitDeleteModel,
emitConvertToDiffusers, emitConvertToDiffusers,
emitMergeDiffusersModels,
emitRequestModelChange, emitRequestModelChange,
emitSaveStagingAreaImageToGallery, emitSaveStagingAreaImageToGallery,
emitRequestEmptyTempFolder, emitRequestEmptyTempFolder,
@ -131,6 +133,10 @@ export const socketioMiddleware = () => {
onModelConverted(data); onModelConverted(data);
}); });
// Wire the server's `modelsMerged` event to its listener.
socketio.on('modelsMerged', (data: InvokeAI.ModelsMergedResponse) => {
onModelsMerged(data);
});
socketio.on('modelChanged', (data: InvokeAI.ModelChangeResponse) => { socketio.on('modelChanged', (data: InvokeAI.ModelChangeResponse) => {
onModelChanged(data); onModelChanged(data);
}); });
@ -210,6 +216,11 @@ export const socketioMiddleware = () => {
break; break;
} }
// Forward the mergeDiffusersModels action's payload to the emitter.
case 'socketio/mergeDiffusersModels': {
emitMergeDiffusersModels(action.payload);
break;
}
case 'socketio/requestModelChange': { case 'socketio/requestModelChange': {
emitRequestModelChange(action.payload); emitRequestModelChange(action.payload);
break; break;

View File

@ -87,7 +87,11 @@ export default function AddModel() {
closeOnOverlayClick={false} closeOnOverlayClick={false}
> >
<ModalOverlay /> <ModalOverlay />
<ModalContent className="modal add-model-modal" fontFamily="Inter"> <ModalContent
className="modal add-model-modal"
fontFamily="Inter"
margin="auto"
>
<ModalHeader>{t('modelmanager:addNewModel')}</ModalHeader> <ModalHeader>{t('modelmanager:addNewModel')}</ModalHeader>
<ModalCloseButton marginTop="0.3rem" /> <ModalCloseButton marginTop="0.3rem" />
<ModalBody className="add-model-modal-body"> <ModalBody className="add-model-modal-body">

View File

@ -0,0 +1,293 @@
import {
Flex,
Modal,
ModalCloseButton,
ModalContent,
ModalHeader,
ModalOverlay,
Radio,
RadioGroup,
Text,
Tooltip,
useDisclosure,
} from '@chakra-ui/react';
import { mergeDiffusersModels } from 'app/socketio/actions';
import { type RootState } from 'app/store';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
import IAISelect from 'common/components/IAISelect';
import { diffusersModelsSelector } from 'features/system/store/systemSelectors';
import { useState } from 'react';
import { useTranslation } from 'react-i18next';
import * as InvokeAI from 'app/invokeai';
import IAISlider from 'common/components/IAISlider';
import IAICheckbox from 'common/components/IAICheckbox';
/**
 * Modal UI for merging two or three installed diffusers-format models into a
 * new model. Collects the merge parameters (source models, alpha, interpolation
 * method, save location, force flag) and dispatches a `mergeDiffusersModels`
 * socketio request to the backend.
 */
export default function MergeModels() {
  const dispatch = useAppDispatch();

  const { isOpen, onOpen, onClose } = useDisclosure();

  const diffusersModels = useAppSelector(diffusersModelsSelector);

  const { t } = useTranslation();

  // Default selections: the first two available diffusers models; the third
  // slot is optional and defaults to the 'none' placeholder.
  const [modelOne, setModelOne] = useState<string>(
    Object.keys(diffusersModels)[0]
  );
  const [modelTwo, setModelTwo] = useState<string>(
    Object.keys(diffusersModels)[1]
  );
  const [modelThree, setModelThree] = useState<string>('none');

  const [mergedModelName, setMergedModelName] = useState<string>('');
  const [modelMergeAlpha, setModelMergeAlpha] = useState<number>(0.5);

  const [modelMergeInterp, setModelMergeInterp] = useState<
    'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference'
  >('weighted_sum');

  const [modelMergeSaveLocType, setModelMergeSaveLocType] = useState<
    'root' | 'custom'
  >('root');

  const [modelMergeCustomSaveLoc, setModelMergeCustomSaveLoc] =
    useState<string>('');

  const [modelMergeForce, setModelMergeForce] = useState<boolean>(false);

  // Each dropdown excludes the models selected in the other two, so the same
  // model cannot be chosen twice. (Predicates return booleans; the previous
  // version returned the model name and relied on string truthiness, which
  // would silently drop a model whose name is the empty string.)
  const modelOneList = Object.keys(diffusersModels).filter(
    (model) => model !== modelTwo && model !== modelThree
  );

  const modelTwoList = Object.keys(diffusersModels).filter(
    (model) => model !== modelOne && model !== modelThree
  );

  const modelThreeList = [
    'none',
    ...Object.keys(diffusersModels).filter(
      (model) => model !== modelOne && model !== modelTwo
    ),
  ];

  const isProcessing = useAppSelector(
    (state: RootState) => state.system.isProcessing
  );

  const mergeModelsHandler = () => {
    // Drop the 'none' placeholder so only real models are sent to the server.
    const modelsToMerge = [modelOne, modelTwo, modelThree].filter(
      (model) => model !== 'none'
    );

    const mergeModelsInfo: InvokeAI.InvokeModelMergingProps = {
      models_to_merge: modelsToMerge,
      // Fall back to a hyphen-joined name (e.g. "modelA-modelB") when the
      // user leaves the name field empty.
      merged_model_name:
        mergedModelName !== '' ? mergedModelName : modelsToMerge.join('-'),
      alpha: modelMergeAlpha,
      interp: modelMergeInterp,
      // null tells the backend to save into the default InvokeAI models dir.
      model_merge_save_path:
        modelMergeSaveLocType === 'root' ? null : modelMergeCustomSaveLoc,
      force: modelMergeForce,
    };

    dispatch(mergeDiffusersModels(mergeModelsInfo));
  };

  return (
    <>
      <IAIButton onClick={onOpen} className="modal-close-btn" size="sm">
        <Flex columnGap="0.5rem" alignItems="center">
          {t('modelmanager:mergeModels')}
        </Flex>
      </IAIButton>

      <Modal
        isOpen={isOpen}
        onClose={onClose}
        size="4xl"
        closeOnOverlayClick={false}
      >
        <ModalOverlay />
        <ModalContent className="modal" fontFamily="Inter" margin="auto">
          <ModalHeader>{t('modelmanager:mergeModels')}</ModalHeader>
          <ModalCloseButton />
          <Flex flexDirection="column" padding="1rem" rowGap={4}>
            <Flex
              flexDirection="column"
              marginBottom="1rem"
              padding="1rem"
              borderRadius="0.3rem"
              backgroundColor="var(--background-color)"
              rowGap={1}
            >
              <Text>{t('modelmanager:modelMergeHeaderHelp1')}</Text>
              <Text fontSize="0.9rem" color="var(--text-color-secondary)">
                {t('modelmanager:modelMergeHeaderHelp2')}
              </Text>
            </Flex>

            {/* Source model pickers. NOTE(review): these selects do not pass a
                `value` prop, so they are uncontrolled — confirm IAISelect
                defaults to the first valid value to match the initial state. */}
            <Flex columnGap={4}>
              <IAISelect
                label={t('modelmanager:modelOne')}
                validValues={modelOneList}
                onChange={(e) => setModelOne(e.target.value)}
              />
              <IAISelect
                label={t('modelmanager:modelTwo')}
                validValues={modelTwoList}
                onChange={(e) => setModelTwo(e.target.value)}
              />
              <IAISelect
                label={t('modelmanager:modelThree')}
                validValues={modelThreeList}
                onChange={(e) => {
                  // A third model forces add_difference interpolation;
                  // removing it restores the two-model default.
                  if (e.target.value !== 'none') {
                    setModelThree(e.target.value);
                    setModelMergeInterp('add_difference');
                  } else {
                    setModelThree('none');
                    setModelMergeInterp('weighted_sum');
                  }
                }}
              />
            </Flex>

            <IAIInput
              label={t('modelmanager:mergedModelName')}
              value={mergedModelName}
              onChange={(e) => setMergedModelName(e.target.value)}
            />

            <Flex
              flexDir="column"
              backgroundColor="var(--background-color)"
              padding="1rem 1rem"
              borderRadius="0.2rem"
              rowGap={2}
            >
              <IAISlider
                label={t('modelmanager:alpha')}
                min={0.01}
                max={0.99}
                step={0.01}
                value={modelMergeAlpha}
                onChange={(v) => setModelMergeAlpha(v)}
                withInput
                withReset
                handleReset={() => setModelMergeAlpha(0.5)}
                withSliderMarks
                sliderMarkRightOffset={-7}
              />
              <Text fontSize="0.9rem" color="var(--text-color-secondary)">
                {t('modelmanager:modelMergeAlphaHelp')}
              </Text>
            </Flex>

            <Flex
              columnGap={4}
              backgroundColor="var(--background-color)"
              padding="1rem 1rem"
              borderRadius="0.2rem"
            >
              <Text
                fontWeight="bold"
                fontSize="0.9rem"
                color="var(--text-color-secondary)"
              >
                {t('modelmanager:interpolationType')}
              </Text>
              <RadioGroup
                value={modelMergeInterp}
                onChange={(
                  v:
                    | 'weighted_sum'
                    | 'sigmoid'
                    | 'inv_sigmoid'
                    | 'add_difference'
                ) => setModelMergeInterp(v)}
              >
                <Flex columnGap={4}>
                  {modelThree === 'none' ? (
                    <>
                      <Radio value="weighted_sum">weighted_sum</Radio>
                      <Radio value="sigmoid">sigmoid</Radio>
                      <Radio value="inv_sigmoid">inv_sigmoid</Radio>
                    </>
                  ) : (
                    <Radio value="add_difference">
                      <Tooltip
                        label={t(
                          'modelmanager:modelMergeInterpAddDifferenceHelp'
                        )}
                      >
                        add_difference
                      </Tooltip>
                    </Radio>
                  )}
                </Flex>
              </RadioGroup>
            </Flex>

            <Flex
              gap={4}
              flexDirection="column"
              backgroundColor="var(--background-color)"
              padding="1rem 1rem"
              borderRadius="0.2rem"
            >
              <Flex columnGap={4}>
                <Text
                  fontWeight="bold"
                  fontSize="0.9rem"
                  color="var(--text-color-secondary)"
                >
                  {t('modelmanager:mergedModelSaveLocation')}
                </Text>
                <RadioGroup
                  value={modelMergeSaveLocType}
                  onChange={(v: 'root' | 'custom') =>
                    setModelMergeSaveLocType(v)
                  }
                >
                  <Flex columnGap={4}>
                    <Radio value="root">
                      {t('modelmanager:invokeAIFolder')}
                    </Radio>
                    <Radio value="custom">{t('modelmanager:custom')}</Radio>
                  </Flex>
                </RadioGroup>
              </Flex>

              {modelMergeSaveLocType === 'custom' && (
                <IAIInput
                  label={t('modelmanager:mergedModelCustomSaveLocation')}
                  value={modelMergeCustomSaveLoc}
                  onChange={(e) => setModelMergeCustomSaveLoc(e.target.value)}
                />
              )}
            </Flex>

            <IAICheckbox
              label={t('modelmanager:ignoreMismatch')}
              isChecked={modelMergeForce}
              onChange={(e) => setModelMergeForce(e.target.checked)}
              fontWeight="bold"
            />

            {/* Disabled while a custom save location is selected but empty. */}
            <IAIButton
              onClick={mergeModelsHandler}
              isLoading={isProcessing}
              isDisabled={
                modelMergeSaveLocType === 'custom' &&
                modelMergeCustomSaveLoc === ''
              }
              className="modal modal-close-btn"
            >
              {t('modelmanager:merge')}
            </IAIButton>
          </Flex>
        </ModalContent>
      </Modal>
    </>
  );
}

View File

@ -14,6 +14,7 @@ import { systemSelector } from 'features/system/store/systemSelectors';
import type { SystemState } from 'features/system/store/systemSlice'; import type { SystemState } from 'features/system/store/systemSlice';
import { isEqual, map } from 'lodash'; import { isEqual, map } from 'lodash';
import type { ChangeEvent, ReactNode } from 'react'; import type { ChangeEvent, ReactNode } from 'react';
import MergeModels from './MergeModels';
const modelListSelector = createSelector( const modelListSelector = createSelector(
systemSelector, systemSelector,
@ -181,7 +182,10 @@ const ModelList = () => {
<Text fontSize="1.4rem" fontWeight="bold"> <Text fontSize="1.4rem" fontWeight="bold">
{t('modelmanager:availableModels')} {t('modelmanager:availableModels')}
</Text> </Text>
<AddModel /> <Flex gap={2}>
<AddModel />
<MergeModels />
</Flex>
</Flex> </Flex>
<IAIInput <IAIInput

View File

@ -1,6 +1,6 @@
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store'; import { RootState } from 'app/store';
import { isEqual, reduce } from 'lodash'; import { isEqual, reduce, pickBy } from 'lodash';
export const systemSelector = (state: RootState) => state.system; export const systemSelector = (state: RootState) => state.system;
@ -28,3 +28,23 @@ export const activeModelSelector = createSelector(
}, },
} }
); );
// Memoized selector returning only the models in 'diffusers' format from the
// system model list, keyed by model name (entries keep their original shape).
export const diffusersModelsSelector = createSelector(
  systemSelector,
  (system) => {
    const { model_list } = system;
    // pickBy keeps entries whose predicate is truthy and preserves the
    // ORIGINAL values — the previous predicate returned `{ name: key,
    // ...model }` as if mapping, but that object was only ever used for its
    // truthiness. A boolean predicate makes the actual behavior explicit.
    return pickBy(model_list, (model) => model.format === 'diffusers');
  },
  {
    memoizeOptions: {
      resultEqualityCheck: isEqual,
    },
  }
);

View File

@ -220,6 +220,12 @@ export const systemSlice = createSlice({
state.isProcessing = true; state.isProcessing = true;
state.currentStatusHasSteps = false; state.currentStatusHasSteps = false;
}, },
// Puts the UI into a non-cancelable busy state while a model merge runs.
modelMergingRequested: (state) => {
state.currentStatus = i18n.t('common:statusMergingModels');
state.isCancelable = false;
state.isProcessing = true;
state.currentStatusHasSteps = false;
},
setSaveIntermediatesInterval: (state, action: PayloadAction<number>) => { setSaveIntermediatesInterval: (state, action: PayloadAction<number>) => {
state.saveIntermediatesInterval = action.payload; state.saveIntermediatesInterval = action.payload;
}, },
@ -272,6 +278,7 @@ export const {
setIsCancelable, setIsCancelable,
modelChangeRequested, modelChangeRequested,
modelConvertRequested, modelConvertRequested,
modelMergingRequested,
setSaveIntermediatesInterval, setSaveIntermediatesInterval,
setEnableImageDebugging, setEnableImageDebugging,
generationRequested, generationRequested,

File diff suppressed because one or more lines are too long