[WebUI] Model Merging (#2699)

This PR brings Model Merging to the WebUI.

Inside the Model Manager, you can now find a new button called Merge
Models. Rest of it is self explanatory.


![firefox_BYCM4YNHEa](https://user-images.githubusercontent.com/54517381/219795631-dbb5c5c4-fc3a-4cdd-9549-18c2e5302835.png)
This commit is contained in:
blessedcoolant 2023-02-18 14:34:35 +13:00 committed by GitHub
commit 767012aec0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 1240 additions and 769 deletions

View File

@ -31,6 +31,8 @@ from ldm.invoke.generator.inpaint import infill_methods
from ldm.invoke.globals import Globals, global_converted_ckpts_dir
from ldm.invoke.pngwriter import PngWriter, retrieve_metadata
from ldm.invoke.prompt_parser import split_weighted_subprompts, Blend
from ldm.invoke.globals import global_models_dir
from ldm.invoke.merge_diffusers import merge_diffusion_models
# Loading Arguments
opt = Args()
@ -205,11 +207,7 @@ class InvokeAIWebServer:
return make_response(response, 200)
except Exception as e:
self.socketio.emit("error", {"message": (str(e))})
print("\n")
traceback.print_exc()
print("\n")
self.handle_exceptions(e)
return make_response("Error uploading file", 500)
self.load_socketio_listeners(self.socketio)
@ -317,10 +315,7 @@ class InvokeAIWebServer:
'found_models': found_models},
)
except Exception as e:
self.socketio.emit("error", {"message": (str(e))})
print("\n")
traceback.print_exc()
self.handle_exceptions(e)
print("\n")
@socketio.on("addNewModel")
@ -350,11 +345,7 @@ class InvokeAIWebServer:
)
print(f">> New Model Added: {model_name}")
except Exception as e:
self.socketio.emit("error", {"message": (str(e))})
print("\n")
traceback.print_exc()
print("\n")
self.handle_exceptions(e)
@socketio.on("deleteModel")
def handle_delete_model(model_name: str):
@ -370,11 +361,7 @@ class InvokeAIWebServer:
)
print(f">> Model Deleted: {model_name}")
except Exception as e:
self.socketio.emit("error", {"message": (str(e))})
print("\n")
traceback.print_exc()
print("\n")
self.handle_exceptions(e)
@socketio.on("requestModelChange")
def handle_set_model(model_name: str):
@ -393,11 +380,7 @@ class InvokeAIWebServer:
{"model_name": model_name, "model_list": model_list},
)
except Exception as e:
self.socketio.emit("error", {"message": (str(e))})
print("\n")
traceback.print_exc()
print("\n")
self.handle_exceptions(e)
@socketio.on('convertToDiffusers')
def convert_to_diffusers(model_to_convert: dict):
@ -428,10 +411,12 @@ class InvokeAIWebServer:
)
if model_to_convert['save_location'] == 'root':
diffusers_path = Path(global_converted_ckpts_dir(), f'{model_name}_diffusers')
diffusers_path = Path(
global_converted_ckpts_dir(), f'{model_name}_diffusers')
if model_to_convert['save_location'] == 'custom' and model_to_convert['custom_location'] is not None:
diffusers_path = Path(model_to_convert['custom_location'], f'{model_name}_diffusers')
diffusers_path = Path(
model_to_convert['custom_location'], f'{model_name}_diffusers')
if diffusers_path.exists():
shutil.rmtree(diffusers_path)
@ -454,11 +439,51 @@ class InvokeAIWebServer:
)
print(f">> Model Converted: {model_name}")
except Exception as e:
self.socketio.emit("error", {"message": (str(e))})
print("\n")
self.handle_exceptions(e)
traceback.print_exc()
print("\n")
@socketio.on('mergeDiffusersModels')
def merge_diffusers_models(model_merge_info: dict):
try:
models_to_merge = model_merge_info['models_to_merge']
model_ids_or_paths = [
self.generate.model_manager.model_name_or_path(x) for x in models_to_merge]
merged_pipe = merge_diffusion_models(
model_ids_or_paths, model_merge_info['alpha'], model_merge_info['interp'], model_merge_info['force'])
dump_path = global_models_dir() / 'merged_models'
if model_merge_info['model_merge_save_path'] is not None:
dump_path = Path(model_merge_info['model_merge_save_path'])
os.makedirs(dump_path, exist_ok=True)
dump_path = dump_path / model_merge_info['merged_model_name']
merged_pipe.save_pretrained(dump_path, safe_serialization=1)
merged_model_config = dict(
model_name=model_merge_info['merged_model_name'],
description=f'Merge of models {", ".join(models_to_merge)}',
commit_to_conf=opt.conf
)
if vae := self.generate.model_manager.config[models_to_merge[0]].get("vae", None):
print(
f">> Using configured VAE assigned to {models_to_merge[0]}")
merged_model_config.update(vae=vae)
self.generate.model_manager.import_diffuser_model(
dump_path, **merged_model_config)
new_model_list = self.generate.model_manager.list_models()
socketio.emit(
"modelsMerged",
{"merged_models": models_to_merge,
"merged_model_name": model_merge_info['merged_model_name'],
"model_list": new_model_list, 'update': True},
)
print(f">> Models Merged: {models_to_merge}")
print(
f">> New Model Added: {model_merge_info['merged_model_name']}")
except Exception as e:
self.handle_exceptions(e)
@socketio.on("requestEmptyTempFolder")
def empty_temp_folder():
@ -479,11 +504,7 @@ class InvokeAIWebServer:
socketio.emit("tempFolderEmptied")
except Exception as e:
self.socketio.emit("error", {"message": (str(e))})
print("\n")
traceback.print_exc()
print("\n")
self.handle_exceptions(e)
@socketio.on("requestSaveStagingAreaImageToGallery")
def save_temp_image_to_gallery(url):
@ -525,11 +546,7 @@ class InvokeAIWebServer:
)
except Exception as e:
self.socketio.emit("error", {"message": (str(e))})
print("\n")
traceback.print_exc()
print("\n")
self.handle_exceptions(e)
@socketio.on("requestLatestImages")
def handle_request_latest_images(category, latest_mtime):
@ -595,11 +612,7 @@ class InvokeAIWebServer:
{"images": image_array, "category": category},
)
except Exception as e:
self.socketio.emit("error", {"message": (str(e))})
print("\n")
traceback.print_exc()
print("\n")
self.handle_exceptions(e)
@socketio.on("requestImages")
def handle_request_images(category, earliest_mtime=None):
@ -674,11 +687,7 @@ class InvokeAIWebServer:
},
)
except Exception as e:
self.socketio.emit("error", {"message": (str(e))})
print("\n")
traceback.print_exc()
print("\n")
self.handle_exceptions(e)
@socketio.on("generateImage")
def handle_generate_image_event(
@ -711,11 +720,7 @@ class InvokeAIWebServer:
facetool_parameters,
)
except Exception as e:
self.socketio.emit("error", {"message": (str(e))})
print("\n")
traceback.print_exc()
print("\n")
self.handle_exceptions(e)
@socketio.on("runPostprocessing")
def handle_run_postprocessing(original_image, postprocessing_parameters):
@ -829,11 +834,7 @@ class InvokeAIWebServer:
},
)
except Exception as e:
self.socketio.emit("error", {"message": (str(e))})
print("\n")
traceback.print_exc()
print("\n")
self.handle_exceptions(e)
@socketio.on("cancel")
def handle_cancel():
@ -858,11 +859,7 @@ class InvokeAIWebServer:
{"url": url, "uuid": uuid, "category": category},
)
except Exception as e:
self.socketio.emit("error", {"message": (str(e))})
print("\n")
traceback.print_exc()
print("\n")
self.handle_exceptions(e)
# App Functions
def get_system_config(self):
@ -1312,11 +1309,7 @@ class InvokeAIWebServer:
# Clear the CUDA cache on an exception
self.empty_cuda_cache()
print(e)
self.socketio.emit("error", {"message": (str(e))})
print("\n")
traceback.print_exc()
print("\n")
self.handle_exceptions(e)
def empty_cuda_cache(self):
if self.generate.device.type == "cuda":
@ -1423,11 +1416,7 @@ class InvokeAIWebServer:
return metadata
except Exception as e:
self.socketio.emit("error", {"message": (str(e))})
print("\n")
traceback.print_exc()
print("\n")
self.handle_exceptions(e)
def parameters_to_post_processed_image_metadata(
self, parameters, original_image_path
@ -1480,11 +1469,7 @@ class InvokeAIWebServer:
return current_metadata
except Exception as e:
self.socketio.emit("error", {"message": (str(e))})
print("\n")
traceback.print_exc()
print("\n")
self.handle_exceptions(e)
def save_result_image(
self,
@ -1528,11 +1513,7 @@ class InvokeAIWebServer:
return os.path.abspath(path)
except Exception as e:
self.socketio.emit("error", {"message": (str(e))})
print("\n")
traceback.print_exc()
print("\n")
self.handle_exceptions(e)
def make_unique_init_image_filename(self, name):
try:
@ -1541,11 +1522,7 @@ class InvokeAIWebServer:
name = f"{split[0]}.{uuid}{split[1]}"
return name
except Exception as e:
self.socketio.emit("error", {"message": (str(e))})
print("\n")
traceback.print_exc()
print("\n")
self.handle_exceptions(e)
def calculate_real_steps(self, steps, strength, has_init_image):
import math
@ -1560,11 +1537,7 @@ class InvokeAIWebServer:
file.writelines(message)
except Exception as e:
self.socketio.emit("error", {"message": (str(e))})
print("\n")
traceback.print_exc()
print("\n")
self.handle_exceptions(e)
def get_image_path_from_url(self, url):
"""Given a url to an image used by the client, returns the absolute file path to that image"""
@ -1595,11 +1568,7 @@ class InvokeAIWebServer:
os.path.join(self.result_path, os.path.basename(url))
)
except Exception as e:
self.socketio.emit("error", {"message": (str(e))})
print("\n")
traceback.print_exc()
print("\n")
self.handle_exceptions(e)
def get_url_from_image_path(self, path):
"""Given an absolute file path to an image, returns the URL that the client can use to load the image"""
@ -1617,11 +1586,7 @@ class InvokeAIWebServer:
else:
return os.path.join(self.result_url, os.path.basename(path))
except Exception as e:
self.socketio.emit("error", {"message": (str(e))})
print("\n")
traceback.print_exc()
print("\n")
self.handle_exceptions(e)
def save_file_unique_uuid_name(self, bytes, name, path):
try:
@ -1640,11 +1605,13 @@ class InvokeAIWebServer:
return file_path
except Exception as e:
self.socketio.emit("error", {"message": (str(e))})
print("\n")
self.handle_exceptions(e)
traceback.print_exc()
print("\n")
def handle_exceptions(self, exception, emit_key: str = 'error'):
self.socketio.emit(emit_key, {"message": (str(exception))})
print("\n")
traceback.print_exc()
print("\n")
class Progress:

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -5,7 +5,7 @@
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>InvokeAI - A Stable Diffusion Toolkit</title>
<link rel="shortcut icon" type="icon" href="./assets/favicon-0d253ced.ico" />
<script type="module" crossorigin src="./assets/index-9237ac63.js"></script>
<script type="module" crossorigin src="./assets/index-7062a172.js"></script>
<link rel="stylesheet" href="./assets/index-14cb2922.css">
</head>

View File

@ -9,6 +9,19 @@
"darkTheme": "Dark",
"lightTheme": "Light",
"greenTheme": "Green",
"langArabic": "العربية",
"langEnglish": "English",
"langDutch": "Nederlands",
"langFrench": "Français",
"langGerman": "Deutsch",
"langItalian": "Italiano",
"langJapanese": "日本語",
"langPolish": "Polski",
"langBrPortuguese": "Português do Brasil",
"langRussian": "Русский",
"langSimplifiedChinese": "简体中文",
"langUkranian": "Украї́нська",
"langSpanish": "Español",
"text2img": "Text To Image",
"img2img": "Image To Image",
"unifiedCanvas": "Unified Canvas",
@ -45,5 +58,9 @@
"statusUpscaling": "Upscaling",
"statusUpscalingESRGAN": "Upscaling (ESRGAN)",
"statusLoadingModel": "Loading Model",
"statusModelChanged": "Model Changed"
"statusModelChanged": "Model Changed",
"statusConvertingModel": "Converting Model",
"statusModelConverted": "Model Converted",
"statusMergingModels": "Merging Models",
"statusMergedModels": "Models Merged"
}

View File

@ -60,5 +60,7 @@
"statusLoadingModel": "Loading Model",
"statusModelChanged": "Model Changed",
"statusConvertingModel": "Converting Model",
"statusModelConverted": "Model Converted"
"statusModelConverted": "Model Converted",
"statusMergingModels": "Merging Models",
"statusMergedModels": "Models Merged"
}

View File

@ -22,7 +22,7 @@
"config": "Config",
"configValidationMsg": "Path to the config file of your model.",
"modelLocation": "Model Location",
"modelLocationValidationMsg": "Path to where your model is located.",
"modelLocationValidationMsg": "Path to where your model is located locally.",
"repo_id": "Repo ID",
"repoIDValidationMsg": "Online repository of your model",
"vaeLocation": "VAE Location",
@ -72,14 +72,33 @@
"convertToDiffusersHelpText4": "This is a one time process only. It might take around 30s-60s depending on the specifications of your computer.",
"convertToDiffusersHelpText5": "Please make sure you have enough disk space. Models generally vary between 4GB-7GB in size.",
"convertToDiffusersHelpText6": "Do you wish to convert this model?",
"convertToDiffusersSaveLocation": "Save Location",
"v1": "v1",
"v2": "v2",
"inpainting": "v1 Inpainting",
"customConfig": "Custom Config",
"pathToCustomConfig": "Path To Custom Config",
"statusConverting": "Converting",
"sameFolder": "Same Folder",
"invokeRoot": "Invoke Models",
"modelConverted": "Model Converted",
"sameFolder": "Same folder",
"invokeRoot": "InvokeAI folder",
"custom": "Custom",
"customSaveLocation": "Custom Save Location"
"customSaveLocation": "Custom Save Location",
"merge": "Merge",
"modelsMerged": "Models Merged",
"mergeModels": "Merge Models",
"modelOne": "Model 1",
"modelTwo": "Model 2",
"modelThree": "Model 3",
"mergedModelName": "Merged Model Name",
"alpha": "Alpha",
"interpolationType": "Interpolation Type",
"mergedModelSaveLocation": "Save Location",
"mergedModelCustomSaveLocation": "Custom Path",
"invokeAIFolder": "Invoke AI Folder",
"ignoreMismatch": "Ignore Mismatches Between Selected Models",
"modelMergeHeaderHelp1": "You can merge upto three different models to create a blend that suits your needs.",
"modelMergeHeaderHelp2": "Only Diffusers are available for merging. If you want to merge a checkpoint model, please convert it to Diffusers first.",
"modelMergeAlphaHelp": "Alpha controls blend strength for the models. Lower alpha values lead to lower influence of the second model.",
"modelMergeInterpAddDifferenceHelp": "In this mode, Model 3 is first subtracted from Model 2. The resulting version is blended with Model 1 with the alpha rate set above."
}

View File

@ -83,5 +83,22 @@
"sameFolder": "Same folder",
"invokeRoot": "InvokeAI folder",
"custom": "Custom",
"customSaveLocation": "Custom Save Location"
"customSaveLocation": "Custom Save Location",
"merge": "Merge",
"modelsMerged": "Models Merged",
"mergeModels": "Merge Models",
"modelOne": "Model 1",
"modelTwo": "Model 2",
"modelThree": "Model 3",
"mergedModelName": "Merged Model Name",
"alpha": "Alpha",
"interpolationType": "Interpolation Type",
"mergedModelSaveLocation": "Save Location",
"mergedModelCustomSaveLocation": "Custom Path",
"invokeAIFolder": "Invoke AI Folder",
"ignoreMismatch": "Ignore Mismatches Between Selected Models",
"modelMergeHeaderHelp1": "You can merge upto three different models to create a blend that suits your needs.",
"modelMergeHeaderHelp2": "Only Diffusers are available for merging. If you want to merge a checkpoint model, please convert it to Diffusers first.",
"modelMergeAlphaHelp": "Alpha controls blend strength for the models. Lower alpha values lead to lower influence of the second model.",
"modelMergeInterpAddDifferenceHelp": "In this mode, Model 3 is first subtracted from Model 2. The resulting version is blended with Model 1 with the alpha rate set above."
}

View File

@ -22,6 +22,7 @@
"upscaling": "Upscaling",
"upscale": "Upscale",
"upscaleImage": "Upscale Image",
"denoisingStrength": "Denoising Strength",
"scale": "Scale",
"otherOptions": "Other Options",
"seamlessTiling": "Seamless Tiling",
@ -46,9 +47,11 @@
"invoke": "Invoke",
"cancel": "Cancel",
"promptPlaceholder": "Type prompt here. [negative tokens], (upweight)++, (downweight)--, swap and blend are available (see docs)",
"negativePrompts": "Negative Prompts",
"sendTo": "Send to",
"sendToImg2Img": "Send to Image to Image",
"sendToUnifiedCanvas": "Send To Unified Canvas",
"copyImage": "Copy Image",
"copyImageToLink": "Copy Image To Link",
"downloadImage": "Download Image",
"openInViewer": "Open In Viewer",

View File

@ -9,6 +9,19 @@
"darkTheme": "Dark",
"lightTheme": "Light",
"greenTheme": "Green",
"langArabic": "العربية",
"langEnglish": "English",
"langDutch": "Nederlands",
"langFrench": "Français",
"langGerman": "Deutsch",
"langItalian": "Italiano",
"langJapanese": "日本語",
"langPolish": "Polski",
"langBrPortuguese": "Português do Brasil",
"langRussian": "Русский",
"langSimplifiedChinese": "简体中文",
"langUkranian": "Украї́нська",
"langSpanish": "Español",
"text2img": "Text To Image",
"img2img": "Image To Image",
"unifiedCanvas": "Unified Canvas",
@ -45,5 +58,9 @@
"statusUpscaling": "Upscaling",
"statusUpscalingESRGAN": "Upscaling (ESRGAN)",
"statusLoadingModel": "Loading Model",
"statusModelChanged": "Model Changed"
"statusModelChanged": "Model Changed",
"statusConvertingModel": "Converting Model",
"statusModelConverted": "Model Converted",
"statusMergingModels": "Merging Models",
"statusMergedModels": "Models Merged"
}

View File

@ -60,5 +60,7 @@
"statusLoadingModel": "Loading Model",
"statusModelChanged": "Model Changed",
"statusConvertingModel": "Converting Model",
"statusModelConverted": "Model Converted"
"statusModelConverted": "Model Converted",
"statusMergingModels": "Merging Models",
"statusMergedModels": "Models Merged"
}

View File

@ -22,7 +22,7 @@
"config": "Config",
"configValidationMsg": "Path to the config file of your model.",
"modelLocation": "Model Location",
"modelLocationValidationMsg": "Path to where your model is located.",
"modelLocationValidationMsg": "Path to where your model is located locally.",
"repo_id": "Repo ID",
"repoIDValidationMsg": "Online repository of your model",
"vaeLocation": "VAE Location",
@ -72,14 +72,33 @@
"convertToDiffusersHelpText4": "This is a one time process only. It might take around 30s-60s depending on the specifications of your computer.",
"convertToDiffusersHelpText5": "Please make sure you have enough disk space. Models generally vary between 4GB-7GB in size.",
"convertToDiffusersHelpText6": "Do you wish to convert this model?",
"convertToDiffusersSaveLocation": "Save Location",
"v1": "v1",
"v2": "v2",
"inpainting": "v1 Inpainting",
"customConfig": "Custom Config",
"pathToCustomConfig": "Path To Custom Config",
"statusConverting": "Converting",
"sameFolder": "Same Folder",
"invokeRoot": "Invoke Models",
"modelConverted": "Model Converted",
"sameFolder": "Same folder",
"invokeRoot": "InvokeAI folder",
"custom": "Custom",
"customSaveLocation": "Custom Save Location"
"customSaveLocation": "Custom Save Location",
"merge": "Merge",
"modelsMerged": "Models Merged",
"mergeModels": "Merge Models",
"modelOne": "Model 1",
"modelTwo": "Model 2",
"modelThree": "Model 3",
"mergedModelName": "Merged Model Name",
"alpha": "Alpha",
"interpolationType": "Interpolation Type",
"mergedModelSaveLocation": "Save Location",
"mergedModelCustomSaveLocation": "Custom Path",
"invokeAIFolder": "Invoke AI Folder",
"ignoreMismatch": "Ignore Mismatches Between Selected Models",
"modelMergeHeaderHelp1": "You can merge upto three different models to create a blend that suits your needs.",
"modelMergeHeaderHelp2": "Only Diffusers are available for merging. If you want to merge a checkpoint model, please convert it to Diffusers first.",
"modelMergeAlphaHelp": "Alpha controls blend strength for the models. Lower alpha values lead to lower influence of the second model.",
"modelMergeInterpAddDifferenceHelp": "In this mode, Model 3 is first subtracted from Model 2. The resulting version is blended with Model 1 with the alpha rate set above."
}

View File

@ -83,5 +83,22 @@
"sameFolder": "Same folder",
"invokeRoot": "InvokeAI folder",
"custom": "Custom",
"customSaveLocation": "Custom Save Location"
"customSaveLocation": "Custom Save Location",
"merge": "Merge",
"modelsMerged": "Models Merged",
"mergeModels": "Merge Models",
"modelOne": "Model 1",
"modelTwo": "Model 2",
"modelThree": "Model 3",
"mergedModelName": "Merged Model Name",
"alpha": "Alpha",
"interpolationType": "Interpolation Type",
"mergedModelSaveLocation": "Save Location",
"mergedModelCustomSaveLocation": "Custom Path",
"invokeAIFolder": "Invoke AI Folder",
"ignoreMismatch": "Ignore Mismatches Between Selected Models",
"modelMergeHeaderHelp1": "You can merge upto three different models to create a blend that suits your needs.",
"modelMergeHeaderHelp2": "Only Diffusers are available for merging. If you want to merge a checkpoint model, please convert it to Diffusers first.",
"modelMergeAlphaHelp": "Alpha controls blend strength for the models. Lower alpha values lead to lower influence of the second model.",
"modelMergeInterpAddDifferenceHelp": "In this mode, Model 3 is first subtracted from Model 2. The resulting version is blended with Model 1 with the alpha rate set above."
}

View File

@ -22,6 +22,7 @@
"upscaling": "Upscaling",
"upscale": "Upscale",
"upscaleImage": "Upscale Image",
"denoisingStrength": "Denoising Strength",
"scale": "Scale",
"otherOptions": "Other Options",
"seamlessTiling": "Seamless Tiling",
@ -46,9 +47,11 @@
"invoke": "Invoke",
"cancel": "Cancel",
"promptPlaceholder": "Type prompt here. [negative tokens], (upweight)++, (downweight)--, swap and blend are available (see docs)",
"negativePrompts": "Negative Prompts",
"sendTo": "Send to",
"sendToImg2Img": "Send to Image to Image",
"sendToUnifiedCanvas": "Send To Unified Canvas",
"copyImage": "Copy Image",
"copyImageToLink": "Copy Image To Link",
"downloadImage": "Download Image",
"openInViewer": "Open In Viewer",

View File

@ -225,6 +225,15 @@ export declare type InvokeModelConversionProps = {
custom_location: string | null;
};
export declare type InvokeModelMergingProps = {
models_to_merge: string[];
alpha: number;
interp: 'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference';
force: boolean;
merged_model_name: string;
model_merge_save_path: string | null;
};
/**
* These types type data received from the server via socketio.
*/
@ -239,6 +248,12 @@ export declare type ModelConvertedResponse = {
model_list: ModelList;
};
export declare type ModelsMergedResponse = {
merged_models: string[];
merged_model_name: string;
model_list: ModelList;
};
export declare type ModelAddedResponse = {
new_model_name: string;
model_list: ModelList;

View File

@ -43,6 +43,11 @@ export const convertToDiffusers =
'socketio/convertToDiffusers'
);
export const mergeDiffusersModels =
createAction<InvokeAI.InvokeModelMergingProps>(
'socketio/mergeDiffusersModels'
);
export const requestModelChange = createAction<string>(
'socketio/requestModelChange'
);

View File

@ -16,6 +16,7 @@ import {
generationRequested,
modelChangeRequested,
modelConvertRequested,
modelMergingRequested,
setIsProcessing,
} from 'features/system/store/systemSlice';
import { InvokeTabName } from 'features/ui/store/tabMap';
@ -185,6 +186,12 @@ const makeSocketIOEmitters = (
dispatch(modelConvertRequested());
socketio.emit('convertToDiffusers', modelToConvert);
},
emitMergeDiffusersModels: (
modelMergeInfo: InvokeAI.InvokeModelMergingProps
) => {
dispatch(modelMergingRequested());
socketio.emit('mergeDiffusersModels', modelMergeInfo);
},
emitRequestModelChange: (modelName: string) => {
dispatch(modelChangeRequested());
socketio.emit('requestModelChange', modelName);

View File

@ -432,6 +432,28 @@ const makeSocketIOListeners = (
})
);
},
onModelsMerged: (data: InvokeAI.ModelsMergedResponse) => {
const { merged_models, merged_model_name, model_list } = data;
dispatch(setModelList(model_list));
dispatch(setCurrentStatus(i18n.t('common:statusMergedModels')));
dispatch(setIsProcessing(false));
dispatch(setIsCancelable(true));
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Models merged: ${merged_models}`,
level: 'info',
})
);
dispatch(
addToast({
title: `${i18n.t('modelmanager:modelsMerged')}: ${merged_model_name}`,
status: 'success',
duration: 2500,
isClosable: true,
})
);
},
onModelChanged: (data: InvokeAI.ModelChangeResponse) => {
const { model_name, model_list } = data;
dispatch(setModelList(model_list));

View File

@ -49,6 +49,7 @@ export const socketioMiddleware = () => {
onNewModelAdded,
onModelDeleted,
onModelConverted,
onModelsMerged,
onModelChangeFailed,
onTempFolderEmptied,
} = makeSocketIOListeners(store);
@ -66,6 +67,7 @@ export const socketioMiddleware = () => {
emitAddNewModel,
emitDeleteModel,
emitConvertToDiffusers,
emitMergeDiffusersModels,
emitRequestModelChange,
emitSaveStagingAreaImageToGallery,
emitRequestEmptyTempFolder,
@ -131,6 +133,10 @@ export const socketioMiddleware = () => {
onModelConverted(data);
});
socketio.on('modelsMerged', (data: InvokeAI.ModelsMergedResponse) => {
onModelsMerged(data);
});
socketio.on('modelChanged', (data: InvokeAI.ModelChangeResponse) => {
onModelChanged(data);
});
@ -210,6 +216,11 @@ export const socketioMiddleware = () => {
break;
}
case 'socketio/mergeDiffusersModels': {
emitMergeDiffusersModels(action.payload);
break;
}
case 'socketio/requestModelChange': {
emitRequestModelChange(action.payload);
break;

View File

@ -87,7 +87,11 @@ export default function AddModel() {
closeOnOverlayClick={false}
>
<ModalOverlay />
<ModalContent className="modal add-model-modal" fontFamily="Inter">
<ModalContent
className="modal add-model-modal"
fontFamily="Inter"
margin="auto"
>
<ModalHeader>{t('modelmanager:addNewModel')}</ModalHeader>
<ModalCloseButton marginTop="0.3rem" />
<ModalBody className="add-model-modal-body">

View File

@ -0,0 +1,293 @@
import {
Flex,
Modal,
ModalCloseButton,
ModalContent,
ModalHeader,
ModalOverlay,
Radio,
RadioGroup,
Text,
Tooltip,
useDisclosure,
} from '@chakra-ui/react';
import { mergeDiffusersModels } from 'app/socketio/actions';
import { type RootState } from 'app/store';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
import IAISelect from 'common/components/IAISelect';
import { diffusersModelsSelector } from 'features/system/store/systemSelectors';
import { useState } from 'react';
import { useTranslation } from 'react-i18next';
import * as InvokeAI from 'app/invokeai';
import IAISlider from 'common/components/IAISlider';
import IAICheckbox from 'common/components/IAICheckbox';
export default function MergeModels() {
const dispatch = useAppDispatch();
const { isOpen, onOpen, onClose } = useDisclosure();
const diffusersModels = useAppSelector(diffusersModelsSelector);
const { t } = useTranslation();
const [modelOne, setModelOne] = useState<string>(
Object.keys(diffusersModels)[0]
);
const [modelTwo, setModelTwo] = useState<string>(
Object.keys(diffusersModels)[1]
);
const [modelThree, setModelThree] = useState<string>('none');
const [mergedModelName, setMergedModelName] = useState<string>('');
const [modelMergeAlpha, setModelMergeAlpha] = useState<number>(0.5);
const [modelMergeInterp, setModelMergeInterp] = useState<
'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference'
>('weighted_sum');
const [modelMergeSaveLocType, setModelMergeSaveLocType] = useState<
'root' | 'custom'
>('root');
const [modelMergeCustomSaveLoc, setModelMergeCustomSaveLoc] =
useState<string>('');
const [modelMergeForce, setModelMergeForce] = useState<boolean>(false);
const modelOneList = Object.keys(diffusersModels).filter((model) => {
if (model !== modelTwo && model !== modelThree) return model;
});
const modelTwoList = Object.keys(diffusersModels).filter((model) => {
if (model !== modelOne && model !== modelThree) return model;
});
const modelThreeList = [
'none',
...Object.keys(diffusersModels).filter((model) => {
if (model !== modelOne && model !== modelTwo) return model;
}),
];
const isProcessing = useAppSelector(
(state: RootState) => state.system.isProcessing
);
const mergeModelsHandler = () => {
let modelsToMerge: string[] = [modelOne, modelTwo, modelThree];
modelsToMerge = modelsToMerge.filter((model) => model !== 'none');
const mergeModelsInfo: InvokeAI.InvokeModelMergingProps = {
models_to_merge: modelsToMerge,
merged_model_name:
mergedModelName !== '' ? mergedModelName : modelsToMerge.join('-'),
alpha: modelMergeAlpha,
interp: modelMergeInterp,
model_merge_save_path:
modelMergeSaveLocType === 'root' ? null : modelMergeCustomSaveLoc,
force: modelMergeForce,
};
dispatch(mergeDiffusersModels(mergeModelsInfo));
};
return (
<>
<IAIButton onClick={onOpen} className="modal-close-btn" size="sm">
<Flex columnGap="0.5rem" alignItems="center">
{t('modelmanager:mergeModels')}
</Flex>
</IAIButton>
<Modal
isOpen={isOpen}
onClose={onClose}
size="4xl"
closeOnOverlayClick={false}
>
<ModalOverlay />
<ModalContent className="modal" fontFamily="Inter" margin="auto">
<ModalHeader>{t('modelmanager:mergeModels')}</ModalHeader>
<ModalCloseButton />
<Flex flexDirection="column" padding="1rem" rowGap={4}>
<Flex
flexDirection="column"
marginBottom="1rem"
padding="1rem"
borderRadius="0.3rem"
backgroundColor="var(--background-color)"
rowGap={1}
>
<Text>{t('modelmanager:modelMergeHeaderHelp1')}</Text>
<Text fontSize="0.9rem" color="var(--text-color-secondary)">
{t('modelmanager:modelMergeHeaderHelp2')}
</Text>
</Flex>
<Flex columnGap={4}>
<IAISelect
label={t('modelmanager:modelOne')}
validValues={modelOneList}
onChange={(e) => setModelOne(e.target.value)}
/>
<IAISelect
label={t('modelmanager:modelTwo')}
validValues={modelTwoList}
onChange={(e) => setModelTwo(e.target.value)}
/>
<IAISelect
label={t('modelmanager:modelThree')}
validValues={modelThreeList}
onChange={(e) => {
if (e.target.value !== 'none') {
setModelThree(e.target.value);
setModelMergeInterp('add_difference');
} else {
setModelThree('none');
setModelMergeInterp('weighted_sum');
}
}}
/>
</Flex>
<IAIInput
label={t('modelmanager:mergedModelName')}
value={mergedModelName}
onChange={(e) => setMergedModelName(e.target.value)}
/>
<Flex
flexDir="column"
backgroundColor="var(--background-color)"
padding="1rem 1rem"
borderRadius="0.2rem"
rowGap={2}
>
<IAISlider
label={t('modelmanager:alpha')}
min={0.01}
max={0.99}
step={0.01}
value={modelMergeAlpha}
onChange={(v) => setModelMergeAlpha(v)}
withInput
withReset
handleReset={() => setModelMergeAlpha(0.5)}
withSliderMarks
sliderMarkRightOffset={-7}
/>
<Text fontSize="0.9rem" color="var(--text-color-secondary)">
{t('modelmanager:modelMergeAlphaHelp')}
</Text>
</Flex>
<Flex
columnGap={4}
backgroundColor="var(--background-color)"
padding="1rem 1rem"
borderRadius="0.2rem"
>
<Text
fontWeight="bold"
fontSize="0.9rem"
color="var(--text-color-secondary)"
>
{t('modelmanager:interpolationType')}
</Text>
<RadioGroup
value={modelMergeInterp}
onChange={(
v:
| 'weighted_sum'
| 'sigmoid'
| 'inv_sigmoid'
| 'add_difference'
) => setModelMergeInterp(v)}
>
<Flex columnGap={4}>
{modelThree === 'none' ? (
<>
<Radio value="weighted_sum">weighted_sum</Radio>
<Radio value="sigmoid">sigmoid</Radio>
<Radio value="inv_sigmoid">inv_sigmoid</Radio>
</>
) : (
<Radio value="add_difference">
<Tooltip
label={t(
'modelmanager:modelMergeInterpAddDifferenceHelp'
)}
>
add_difference
</Tooltip>
</Radio>
)}
</Flex>
</RadioGroup>
</Flex>
<Flex
gap={4}
flexDirection="column"
backgroundColor="var(--background-color)"
padding="1rem 1rem"
borderRadius="0.2rem"
>
<Flex columnGap={4}>
<Text
fontWeight="bold"
fontSize="0.9rem"
color="var(--text-color-secondary)"
>
{t('modelmanager:mergedModelSaveLocation')}
</Text>
<RadioGroup
value={modelMergeSaveLocType}
onChange={(v: 'root' | 'custom') =>
setModelMergeSaveLocType(v)
}
>
<Flex columnGap={4}>
<Radio value="root">
{t('modelmanager:invokeAIFolder')}
</Radio>
<Radio value="custom">{t('modelmanager:custom')}</Radio>
</Flex>
</RadioGroup>
</Flex>
{modelMergeSaveLocType === 'custom' && (
<IAIInput
label={t('modelmanager:mergedModelCustomSaveLocation')}
value={modelMergeCustomSaveLoc}
onChange={(e) => setModelMergeCustomSaveLoc(e.target.value)}
/>
)}
</Flex>
<IAICheckbox
label={t('modelmanager:ignoreMismatch')}
isChecked={modelMergeForce}
onChange={(e) => setModelMergeForce(e.target.checked)}
fontWeight="bold"
/>
<IAIButton
onClick={mergeModelsHandler}
isLoading={isProcessing}
isDisabled={
modelMergeSaveLocType === 'custom' &&
modelMergeCustomSaveLoc === ''
}
className="modal modal-close-btn"
>
{t('modelmanager:merge')}
</IAIButton>
</Flex>
</ModalContent>
</Modal>
</>
);
}

View File

@ -14,6 +14,7 @@ import { systemSelector } from 'features/system/store/systemSelectors';
import type { SystemState } from 'features/system/store/systemSlice';
import { isEqual, map } from 'lodash';
import type { ChangeEvent, ReactNode } from 'react';
import MergeModels from './MergeModels';
const modelListSelector = createSelector(
systemSelector,
@ -181,7 +182,10 @@ const ModelList = () => {
<Text fontSize="1.4rem" fontWeight="bold">
{t('modelmanager:availableModels')}
</Text>
<AddModel />
<Flex gap={2}>
<AddModel />
<MergeModels />
</Flex>
</Flex>
<IAIInput

View File

@ -1,6 +1,6 @@
import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store';
import { isEqual, reduce } from 'lodash';
import { isEqual, reduce, pickBy } from 'lodash';
export const systemSelector = (state: RootState) => state.system;
@ -28,3 +28,23 @@ export const activeModelSelector = createSelector(
},
}
);
export const diffusersModelsSelector = createSelector(
systemSelector,
(system) => {
const { model_list } = system;
const diffusersModels = pickBy(model_list, (model, key) => {
if (model.format === 'diffusers') {
return { name: key, ...model };
}
});
return diffusersModels;
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);

View File

@ -220,6 +220,12 @@ export const systemSlice = createSlice({
state.isProcessing = true;
state.currentStatusHasSteps = false;
},
modelMergingRequested: (state) => {
state.currentStatus = i18n.t('common:statusMergingModels');
state.isCancelable = false;
state.isProcessing = true;
state.currentStatusHasSteps = false;
},
setSaveIntermediatesInterval: (state, action: PayloadAction<number>) => {
state.saveIntermediatesInterval = action.payload;
},
@ -272,6 +278,7 @@ export const {
setIsCancelable,
modelChangeRequested,
modelConvertRequested,
modelMergingRequested,
setSaveIntermediatesInterval,
setEnableImageDebugging,
generationRequested,

File diff suppressed because one or more lines are too long