From ebe0071ed20b6d8731387abd7c2fb27d870eec4f Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Sat, 18 Feb 2023 07:19:25 +1300 Subject: [PATCH] feat: [WebUI] Model Merging --- invokeai/backend/invoke_ai_web_server.py | 184 +++++------ .../frontend/public/locales/common/en-US.json | 19 +- .../frontend/public/locales/common/en.json | 4 +- .../public/locales/modelmanager/en-US.json | 27 +- .../public/locales/modelmanager/en.json | 21 +- .../public/locales/parameters/en-US.json | 3 + invokeai/frontend/src/app/invokeai.d.ts | 15 + invokeai/frontend/src/app/socketio/actions.ts | 5 + .../frontend/src/app/socketio/emitters.ts | 7 + .../frontend/src/app/socketio/listeners.ts | 22 ++ .../frontend/src/app/socketio/middleware.ts | 11 + .../components/ModelManager/AddModel.tsx | 6 +- .../components/ModelManager/MergeModels.tsx | 292 ++++++++++++++++++ .../components/ModelManager/ModelList.tsx | 6 +- .../features/system/store/systemSelectors.ts | 22 +- .../src/features/system/store/systemSlice.ts | 7 + 16 files changed, 530 insertions(+), 121 deletions(-) create mode 100644 invokeai/frontend/src/features/system/components/ModelManager/MergeModels.tsx diff --git a/invokeai/backend/invoke_ai_web_server.py b/invokeai/backend/invoke_ai_web_server.py index 7ea100db20..e84737aafc 100644 --- a/invokeai/backend/invoke_ai_web_server.py +++ b/invokeai/backend/invoke_ai_web_server.py @@ -31,6 +31,8 @@ from ldm.invoke.generator.inpaint import infill_methods from ldm.invoke.globals import Globals, global_converted_ckpts_dir from ldm.invoke.pngwriter import PngWriter, retrieve_metadata from ldm.invoke.prompt_parser import split_weighted_subprompts, Blend +from ldm.invoke.globals import global_models_dir +from ldm.invoke.merge_diffusers import merge_diffusion_models # Loading Arguments opt = Args() @@ -205,11 +207,7 @@ class InvokeAIWebServer: return make_response(response, 200) except Exception as e: - self.socketio.emit("error", {"message": (str(e))}) - print("\n") - - traceback.print_exc() - print("\n") + self.handle_exceptions(e) return make_response("Error uploading file", 500) self.load_socketio_listeners(self.socketio) @@ -317,10 +315,7 @@ class InvokeAIWebServer: 'found_models': found_models}, ) except Exception as e: - self.socketio.emit("error", {"message": (str(e))}) - print("\n") - - traceback.print_exc() + self.handle_exceptions(e) print("\n") @socketio.on("addNewModel") @@ -350,11 +345,7 @@ class InvokeAIWebServer: ) print(f">> New Model Added: {model_name}") except Exception as e: - self.socketio.emit("error", {"message": (str(e))}) - print("\n") - - traceback.print_exc() - print("\n") + self.handle_exceptions(e) @socketio.on("deleteModel") def handle_delete_model(model_name: str): @@ -370,11 +361,7 @@ class InvokeAIWebServer: ) print(f">> Model Deleted: {model_name}") except Exception as e: - self.socketio.emit("error", {"message": (str(e))}) - print("\n") - - traceback.print_exc() - print("\n") + self.handle_exceptions(e) @socketio.on("requestModelChange") def handle_set_model(model_name: str): @@ -393,11 +380,7 @@ class InvokeAIWebServer: {"model_name": model_name, "model_list": model_list}, ) except Exception as e: - self.socketio.emit("error", {"message": (str(e))}) - print("\n") - - traceback.print_exc() - print("\n") + self.handle_exceptions(e) @socketio.on('convertToDiffusers') def convert_to_diffusers(model_to_convert: dict): @@ -428,10 +411,12 @@ class InvokeAIWebServer: ) if model_to_convert['save_location'] == 'root': - diffusers_path = Path(global_converted_ckpts_dir(), f'{model_name}_diffusers') - + diffusers_path = Path( + global_converted_ckpts_dir(), f'{model_name}_diffusers') + if model_to_convert['save_location'] == 'custom' and model_to_convert['custom_location'] is not None: - diffusers_path = Path(model_to_convert['custom_location'], f'{model_name}_diffusers') + diffusers_path = Path( + model_to_convert['custom_location'], f'{model_name}_diffusers') if diffusers_path.exists(): shutil.rmtree(diffusers_path) @@ -454,11 +439,48 @@ class InvokeAIWebServer: ) print(f">> Model Converted: {model_name}") except Exception as e: - self.socketio.emit("error", {"message": (str(e))}) - print("\n") + self.handle_exceptions(e) - traceback.print_exc() - print("\n") + @socketio.on('mergeDiffusersModels') + def merge_diffusers_models(model_merge_info: dict): + models_to_merge = model_merge_info['models_to_merge'] + model_ids_or_paths = [ + self.generate.model_manager.model_name_or_path(x) for x in models_to_merge] + merged_pipe = merge_diffusion_models( + model_ids_or_paths, model_merge_info['alpha'], model_merge_info['interp'], model_merge_info['force']) + + dump_path = global_models_dir() / 'merged_models' + if model_merge_info['model_merge_save_path'] is not None: + dump_path = Path(model_merge_info['model_merge_save_path']) + + os.makedirs(dump_path, exist_ok=True) + dump_path = dump_path / model_merge_info['merged_model_name'] + merged_pipe.save_pretrained(dump_path, safe_serialization=1) + + merged_model_config = dict( + model_name=model_merge_info['merged_model_name'], + description=f'Merge of models {", ".join(models_to_merge)}', + commit_to_conf=opt.conf + ) + + if vae := self.generate.model_manager.config[models_to_merge[0]].get("vae", None): + print( + f">> Using configured VAE assigned to {models_to_merge[0]}") + merged_model_config.update(vae=vae) + + self.generate.model_manager.import_diffuser_model( + dump_path, **merged_model_config) + new_model_list = self.generate.model_manager.list_models() + + socketio.emit( + "modelsMerged", + {"merged_models": models_to_merge, + "merged_model_name": model_merge_info['merged_model_name'], + "model_list": new_model_list, 'update': True}, + ) + print(f">> Models Merged: {models_to_merge}") + print( + f">> New Model Added: {model_merge_info['merged_model_name']}") @socketio.on("requestEmptyTempFolder") def empty_temp_folder(): @@ -479,11 +501,7 @@ class InvokeAIWebServer: socketio.emit("tempFolderEmptied") except Exception as e: - self.socketio.emit("error", {"message": (str(e))}) - print("\n") - - traceback.print_exc() - print("\n") + self.handle_exceptions(e) @socketio.on("requestSaveStagingAreaImageToGallery") def save_temp_image_to_gallery(url): @@ -525,11 +543,7 @@ class InvokeAIWebServer: ) except Exception as e: - self.socketio.emit("error", {"message": (str(e))}) - print("\n") - - traceback.print_exc() - print("\n") + self.handle_exceptions(e) @socketio.on("requestLatestImages") def handle_request_latest_images(category, latest_mtime): @@ -595,11 +609,7 @@ class InvokeAIWebServer: {"images": image_array, "category": category}, ) except Exception as e: - self.socketio.emit("error", {"message": (str(e))}) - print("\n") - - traceback.print_exc() - print("\n") + self.handle_exceptions(e) @socketio.on("requestImages") def handle_request_images(category, earliest_mtime=None): @@ -674,11 +684,7 @@ class InvokeAIWebServer: }, ) except Exception as e: - self.socketio.emit("error", {"message": (str(e))}) - print("\n") - - traceback.print_exc() - print("\n") + self.handle_exceptions(e) @socketio.on("generateImage") def handle_generate_image_event( @@ -711,11 +717,7 @@ class InvokeAIWebServer: facetool_parameters, ) except Exception as e: - self.socketio.emit("error", {"message": (str(e))}) - print("\n") - - traceback.print_exc() - print("\n") + self.handle_exceptions(e) @socketio.on("runPostprocessing") def handle_run_postprocessing(original_image, postprocessing_parameters): @@ -829,11 +831,7 @@ class InvokeAIWebServer: }, ) except Exception as e: - self.socketio.emit("error", {"message": (str(e))}) - print("\n") - - traceback.print_exc() - print("\n") + self.handle_exceptions(e) @socketio.on("cancel") def handle_cancel(): @@ -858,11 +856,7 @@ class InvokeAIWebServer: {"url": url, "uuid": uuid, "category": category}, ) except Exception as e: - self.socketio.emit("error", {"message": (str(e))}) - print("\n") - - traceback.print_exc() - print("\n") + self.handle_exceptions(e) # App Functions def get_system_config(self): @@ -1312,11 +1306,7 @@ class InvokeAIWebServer: # Clear the CUDA cache on an exception self.empty_cuda_cache() print(e) - self.socketio.emit("error", {"message": (str(e))}) - print("\n") - - traceback.print_exc() - print("\n") + self.handle_exceptions(e) def empty_cuda_cache(self): if self.generate.device.type == "cuda": @@ -1423,11 +1413,7 @@ class InvokeAIWebServer: return metadata except Exception as e: - self.socketio.emit("error", {"message": (str(e))}) - print("\n") - - traceback.print_exc() - print("\n") + self.handle_exceptions(e) def parameters_to_post_processed_image_metadata( self, parameters, original_image_path @@ -1480,11 +1466,7 @@ class InvokeAIWebServer: return current_metadata except Exception as e: - self.socketio.emit("error", {"message": (str(e))}) - print("\n") - - traceback.print_exc() - print("\n") + self.handle_exceptions(e) def save_result_image( self, @@ -1528,11 +1510,7 @@ class InvokeAIWebServer: return os.path.abspath(path) except Exception as e: - self.socketio.emit("error", {"message": (str(e))}) - print("\n") - - traceback.print_exc() - print("\n") + self.handle_exceptions(e) def make_unique_init_image_filename(self, name): try: @@ -1541,11 +1519,7 @@ class InvokeAIWebServer: name = f"{split[0]}.{uuid}{split[1]}" return name except Exception as e: - self.socketio.emit("error", {"message": (str(e))}) - print("\n") - - traceback.print_exc() - print("\n") + self.handle_exceptions(e) def calculate_real_steps(self, steps, strength, has_init_image): import math @@ -1560,11 +1534,7 @@ class InvokeAIWebServer: file.writelines(message) except Exception as e: - self.socketio.emit("error", {"message": (str(e))}) - print("\n") - - traceback.print_exc() - print("\n") + self.handle_exceptions(e) def get_image_path_from_url(self, url): """Given a url to an image used by the client, returns the absolute file path to that image""" @@ -1595,11 +1565,7 @@ class InvokeAIWebServer: os.path.join(self.result_path, os.path.basename(url)) ) except Exception as e: - self.socketio.emit("error", {"message": (str(e))}) - print("\n") - - traceback.print_exc() - print("\n") + self.handle_exceptions(e) def get_url_from_image_path(self, path): """Given an absolute file path to an image, returns the URL that the client can use to load the image""" @@ -1617,11 +1583,7 @@ class InvokeAIWebServer: else: return os.path.join(self.result_url, os.path.basename(path)) except Exception as e: - self.socketio.emit("error", {"message": (str(e))}) - print("\n") - - traceback.print_exc() - print("\n") + self.handle_exceptions(e) def save_file_unique_uuid_name(self, bytes, name, path): try: @@ -1640,11 +1602,13 @@ class InvokeAIWebServer: return file_path except Exception as e: - self.socketio.emit("error", {"message": (str(e))}) - print("\n") + self.handle_exceptions(e) - traceback.print_exc() - print("\n") + def handle_exceptions(self, exception, emit_key: str = 'error'): + self.socketio.emit(emit_key, {"message": (str(exception))}) + print("\n") + traceback.print_exc() + print("\n") class Progress: diff --git a/invokeai/frontend/public/locales/common/en-US.json b/invokeai/frontend/public/locales/common/en-US.json index 00b60df420..c5ec98b6ae 100644 --- a/invokeai/frontend/public/locales/common/en-US.json +++ b/invokeai/frontend/public/locales/common/en-US.json @@ -9,6 +9,19 @@ "darkTheme": "Dark", "lightTheme": "Light", "greenTheme": "Green", + "langArabic": "العربية", + "langEnglish": "English", + "langDutch": "Nederlands", + "langFrench": "Français", + "langGerman": "Deutsch", + "langItalian": "Italiano", + "langJapanese": "日本語", + "langPolish": "Polski", + "langBrPortuguese": "Português do Brasil", + "langRussian": "Русский", + "langSimplifiedChinese": "简体中文", + "langUkranian": "Украї́нська", + "langSpanish": "Español", "text2img": "Text To Image", "img2img": "Image To Image", "unifiedCanvas": "Unified Canvas", @@ -45,5 +58,9 @@ "statusUpscaling": "Upscaling", "statusUpscalingESRGAN": "Upscaling (ESRGAN)", "statusLoadingModel": "Loading Model", - "statusModelChanged": "Model Changed" + "statusModelChanged": "Model Changed", + "statusConvertingModel": "Converting Model", + "statusModelConverted": "Model Converted", + "statusMergingModels": "Merging Models", + "statusMergedModels": "Models Merged" } diff --git a/invokeai/frontend/public/locales/common/en.json b/invokeai/frontend/public/locales/common/en.json index 556da133de..c5ec98b6ae 100644 --- a/invokeai/frontend/public/locales/common/en.json +++ b/invokeai/frontend/public/locales/common/en.json @@ -60,5 +60,7 @@ "statusLoadingModel": "Loading Model", "statusModelChanged": "Model Changed", "statusConvertingModel": "Converting Model", - "statusModelConverted": "Model Converted" + "statusModelConverted": "Model Converted", + "statusMergingModels": "Merging Models", + "statusMergedModels": "Models Merged" } diff --git a/invokeai/frontend/public/locales/modelmanager/en-US.json b/invokeai/frontend/public/locales/modelmanager/en-US.json index a58592bd2f..c13542d3f8 100644 --- a/invokeai/frontend/public/locales/modelmanager/en-US.json +++ b/invokeai/frontend/public/locales/modelmanager/en-US.json @@ -22,7 +22,7 @@ "config": "Config", "configValidationMsg": "Path to the config file of your model.", "modelLocation": "Model Location", - "modelLocationValidationMsg": "Path to where your model is located.", + "modelLocationValidationMsg": "Path to where your model is located locally.", "repo_id": "Repo ID", "repoIDValidationMsg": "Online repository of your model", "vaeLocation": "VAE Location", @@ -72,14 +72,33 @@ "convertToDiffusersHelpText4": "This is a one time process only. It might take around 30s-60s depending on the specifications of your computer.", "convertToDiffusersHelpText5": "Please make sure you have enough disk space. Models generally vary between 4GB-7GB in size.", "convertToDiffusersHelpText6": "Do you wish to convert this model?", + "convertToDiffusersSaveLocation": "Save Location", "v1": "v1", "v2": "v2", "inpainting": "v1 Inpainting", "customConfig": "Custom Config", "pathToCustomConfig": "Path To Custom Config", "statusConverting": "Converting", - "sameFolder": "Same Folder", - "invokeRoot": "Invoke Models", + "modelConverted": "Model Converted", + "sameFolder": "Same folder", + "invokeRoot": "InvokeAI folder", "custom": "Custom", - "customSaveLocation": "Custom Save Location" + "customSaveLocation": "Custom Save Location", + "merge": "Merge", + "modelsMerged": "Models Merged", + "mergeModels": "Merge Models", + "modelOne": "Model 1", + "modelTwo": "Model 2", + "modelThree": "Model 3", + "mergedModelName": "Merged Model Name", + "alpha": "Alpha", + "interpolationType": "Interpolation Type", + "mergedModelSaveLocation": "Save Location", + "mergedModelCustomSaveLocation": "Custom Path", + "invokeAIFolder": "Invoke AI Folder", + "ignoreMismatch": "Ignore Mismatches Between Selected Models", + "modelMergeHeaderHelp1": "You can merge upto three different models to create a blend that suits your needs.", + "modelMergeHeaderHelp2": "Only Diffusers are available for merging. If you want to merge a checkpoint model, please convert it to Diffusers first.", + "modelMergeAlphaHelp": "Alpha controls blend strength for the models. Lower alpha values lead to lower influence of the second model.", + "modelMergeInterpAddDifferenceHelp": "In this mode, Model 3 is first subtracted from Model 2. The resulting version is blended with Model 1 with the alpha rate set above." } diff --git a/invokeai/frontend/public/locales/modelmanager/en.json b/invokeai/frontend/public/locales/modelmanager/en.json index be4830799f..c13542d3f8 100644 --- a/invokeai/frontend/public/locales/modelmanager/en.json +++ b/invokeai/frontend/public/locales/modelmanager/en.json @@ -76,12 +76,29 @@ "v1": "v1", "v2": "v2", "inpainting": "v1 Inpainting", - "customConfig": "Custom Config", + "customConfig": "Custom Config", "pathToCustomConfig": "Path To Custom Config", "statusConverting": "Converting", "modelConverted": "Model Converted", "sameFolder": "Same folder", "invokeRoot": "InvokeAI folder", "custom": "Custom", - "customSaveLocation": "Custom Save Location" + "customSaveLocation": "Custom Save Location", + "merge": "Merge", + "modelsMerged": "Models Merged", + "mergeModels": "Merge Models", + "modelOne": "Model 1", + "modelTwo": "Model 2", + "modelThree": "Model 3", + "mergedModelName": "Merged Model Name", + "alpha": "Alpha", + "interpolationType": "Interpolation Type", + "mergedModelSaveLocation": "Save Location", + "mergedModelCustomSaveLocation": "Custom Path", + "invokeAIFolder": "Invoke AI Folder", + "ignoreMismatch": "Ignore Mismatches Between Selected Models", + "modelMergeHeaderHelp1": "You can merge upto three different models to create a blend that suits your needs.", + "modelMergeHeaderHelp2": "Only Diffusers are available for merging. If you want to merge a checkpoint model, please convert it to Diffusers first.", + "modelMergeAlphaHelp": "Alpha controls blend strength for the models. Lower alpha values lead to lower influence of the second model.", + "modelMergeInterpAddDifferenceHelp": "In this mode, Model 3 is first subtracted from Model 2. The resulting version is blended with Model 1 with the alpha rate set above." } diff --git a/invokeai/frontend/public/locales/parameters/en-US.json b/invokeai/frontend/public/locales/parameters/en-US.json index d67a659e47..f1b91c3959 100644 --- a/invokeai/frontend/public/locales/parameters/en-US.json +++ b/invokeai/frontend/public/locales/parameters/en-US.json @@ -22,6 +22,7 @@ "upscaling": "Upscaling", "upscale": "Upscale", "upscaleImage": "Upscale Image", + "denoisingStrength": "Denoising Strength", "scale": "Scale", "otherOptions": "Other Options", "seamlessTiling": "Seamless Tiling", @@ -46,9 +47,11 @@ "invoke": "Invoke", "cancel": "Cancel", "promptPlaceholder": "Type prompt here. [negative tokens], (upweight)++, (downweight)--, swap and blend are available (see docs)", + "negativePrompts": "Negative Prompts", "sendTo": "Send to", "sendToImg2Img": "Send to Image to Image", "sendToUnifiedCanvas": "Send To Unified Canvas", + "copyImage": "Copy Image", "copyImageToLink": "Copy Image To Link", "downloadImage": "Download Image", "openInViewer": "Open In Viewer", diff --git a/invokeai/frontend/src/app/invokeai.d.ts b/invokeai/frontend/src/app/invokeai.d.ts index 4043df9ddc..afa5fcec3d 100644 --- a/invokeai/frontend/src/app/invokeai.d.ts +++ b/invokeai/frontend/src/app/invokeai.d.ts @@ -225,6 +225,15 @@ export declare type InvokeModelConversionProps = { custom_location: string | null; }; +export declare type InvokeModelMergingProps = { + models_to_merge: string[]; + alpha: number; + interp: 'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference'; + force: boolean; + merged_model_name: string; + model_merge_save_path: string | null; +}; + /** * These types type data received from the server via socketio. */ @@ -239,6 +248,12 @@ export declare type ModelConvertedResponse = { model_list: ModelList; }; +export declare type ModelsMergedResponse = { + merged_models: string[]; + merged_model_name: string; + model_list: ModelList; +}; + export declare type ModelAddedResponse = { new_model_name: string; model_list: ModelList; diff --git a/invokeai/frontend/src/app/socketio/actions.ts b/invokeai/frontend/src/app/socketio/actions.ts index e0a8dbc9e4..57758d1914 100644 --- a/invokeai/frontend/src/app/socketio/actions.ts +++ b/invokeai/frontend/src/app/socketio/actions.ts @@ -43,6 +43,11 @@ export const convertToDiffusers = 'socketio/convertToDiffusers' ); +export const mergeDiffusersModels = + createAction( + 'socketio/mergeDiffusersModels' + ); + export const requestModelChange = createAction( 'socketio/requestModelChange' ); diff --git a/invokeai/frontend/src/app/socketio/emitters.ts b/invokeai/frontend/src/app/socketio/emitters.ts index a01a7ecc6b..2aa1e03552 100644 --- a/invokeai/frontend/src/app/socketio/emitters.ts +++ b/invokeai/frontend/src/app/socketio/emitters.ts @@ -16,6 +16,7 @@ import { generationRequested, modelChangeRequested, modelConvertRequested, + modelMergingRequested, setIsProcessing, } from 'features/system/store/systemSlice'; import { InvokeTabName } from 'features/ui/store/tabMap'; @@ -185,6 +186,12 @@ const makeSocketIOEmitters = ( dispatch(modelConvertRequested()); socketio.emit('convertToDiffusers', modelToConvert); }, + emitMergeDiffusersModels: ( + modelMergeInfo: InvokeAI.InvokeModelMergingProps + ) => { + dispatch(modelMergingRequested()); + socketio.emit('mergeDiffusersModels', modelMergeInfo); + }, emitRequestModelChange: (modelName: string) => { dispatch(modelChangeRequested()); socketio.emit('requestModelChange', modelName); diff --git a/invokeai/frontend/src/app/socketio/listeners.ts b/invokeai/frontend/src/app/socketio/listeners.ts index bd40b6f5eb..ba28d941f6 100644 --- a/invokeai/frontend/src/app/socketio/listeners.ts +++ b/invokeai/frontend/src/app/socketio/listeners.ts @@ -432,6 +432,28 @@ const makeSocketIOListeners = ( }) ); }, + onModelsMerged: (data: InvokeAI.ModelsMergedResponse) => { + const { merged_models, merged_model_name, model_list } = data; + dispatch(setModelList(model_list)); + dispatch(setCurrentStatus(i18n.t('common:statusMergedModels'))); + dispatch(setIsProcessing(false)); + dispatch(setIsCancelable(true)); + dispatch( + addLogEntry({ + timestamp: dateFormat(new Date(), 'isoDateTime'), + message: `Models merged: ${merged_models}`, + level: 'info', + }) + ); + dispatch( + addToast({ + title: `${i18n.t('modelmanager:modelsMerged')}: ${merged_model_name}`, + status: 'success', + duration: 2500, + isClosable: true, + }) + ); + }, onModelChanged: (data: InvokeAI.ModelChangeResponse) => { const { model_name, model_list } = data; dispatch(setModelList(model_list)); diff --git a/invokeai/frontend/src/app/socketio/middleware.ts b/invokeai/frontend/src/app/socketio/middleware.ts index e08725adbc..a28e4edc80 100644 --- a/invokeai/frontend/src/app/socketio/middleware.ts +++ b/invokeai/frontend/src/app/socketio/middleware.ts @@ -49,6 +49,7 @@ export const socketioMiddleware = () => { onNewModelAdded, onModelDeleted, onModelConverted, + onModelsMerged, onModelChangeFailed, onTempFolderEmptied, } = makeSocketIOListeners(store); @@ -66,6 +67,7 @@ export const socketioMiddleware = () => { emitAddNewModel, emitDeleteModel, emitConvertToDiffusers, + emitMergeDiffusersModels, emitRequestModelChange, emitSaveStagingAreaImageToGallery, emitRequestEmptyTempFolder, @@ -131,6 +133,10 @@ export const socketioMiddleware = () => { onModelConverted(data); }); + socketio.on('modelsMerged', (data: InvokeAI.ModelsMergedResponse) => { + onModelsMerged(data); + }); + socketio.on('modelChanged', (data: InvokeAI.ModelChangeResponse) => { onModelChanged(data); }); @@ -210,6 +216,11 @@ export const socketioMiddleware = () => { break; } + case 'socketio/mergeDiffusersModels': { + emitMergeDiffusersModels(action.payload); + break; + } + case 'socketio/requestModelChange': { emitRequestModelChange(action.payload); break; diff --git a/invokeai/frontend/src/features/system/components/ModelManager/AddModel.tsx b/invokeai/frontend/src/features/system/components/ModelManager/AddModel.tsx index c24dff9f72..edfb955588 100644 --- a/invokeai/frontend/src/features/system/components/ModelManager/AddModel.tsx +++ b/invokeai/frontend/src/features/system/components/ModelManager/AddModel.tsx @@ -87,7 +87,11 @@ export default function AddModel() { closeOnOverlayClick={false} > - + {t('modelmanager:addNewModel')} diff --git a/invokeai/frontend/src/features/system/components/ModelManager/MergeModels.tsx b/invokeai/frontend/src/features/system/components/ModelManager/MergeModels.tsx new file mode 100644 index 0000000000..ce2661cb9b --- /dev/null +++ b/invokeai/frontend/src/features/system/components/ModelManager/MergeModels.tsx @@ -0,0 +1,292 @@ +import { + Flex, + Modal, + ModalCloseButton, + ModalContent, + ModalHeader, + ModalOverlay, + Radio, + RadioGroup, + Text, + Tooltip, + useDisclosure, +} from '@chakra-ui/react'; +import { mergeDiffusersModels } from 'app/socketio/actions'; +import { type RootState } from 'app/store'; +import { useAppDispatch, useAppSelector } from 'app/storeHooks'; +import IAIButton from 'common/components/IAIButton'; +import IAIInput from 'common/components/IAIInput'; +import IAISelect from 'common/components/IAISelect'; +import { diffusersModelsSelector } from 'features/system/store/systemSelectors'; +import { useState } from 'react'; +import { useTranslation } from 'react-i18next'; +import * as InvokeAI from 'app/invokeai'; +import IAISlider from 'common/components/IAISlider'; +import IAICheckbox from 'common/components/IAICheckbox'; + +export default function MergeModels() { + const dispatch = useAppDispatch(); + + const { isOpen, onOpen, onClose } = useDisclosure(); + + const diffusersModels = useAppSelector(diffusersModelsSelector); + + const { t } = useTranslation(); + + const [modelOne, setModelOne] = useState( + Object.keys(diffusersModels)[0] + ); + const [modelTwo, setModelTwo] = useState( + Object.keys(diffusersModels)[1] + ); + const [modelThree, setModelThree] = useState('none'); + + const [mergedModelName, setMergedModelName] = useState(''); + const [modelMergeAlpha, setModelMergeAlpha] = useState(0.5); + + const [modelMergeInterp, setModelMergeInterp] = useState< + 'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference' + >('weighted_sum'); + + const [modelMergeSaveLocType, setModelMergeSaveLocType] = useState< + 'root' | 'custom' + >('root'); + + const [modelMergeCustomSaveLoc, setModelMergeCustomSaveLoc] = + useState(''); + + const [modelMergeForce, setModelMergeForce] = useState(false); + + const modelOneList = Object.keys(diffusersModels).filter((model) => { + if (model !== modelTwo && model !== modelThree) return model; + }); + + const modelTwoList = Object.keys(diffusersModels).filter((model) => { + if (model !== modelOne && model !== modelThree) return model; + }); + + const modelThreeList = [ + 'none', + ...Object.keys(diffusersModels).filter((model) => { + if (model !== modelOne && model !== modelTwo) return model; + }), + ]; + + const isProcessing = useAppSelector( + (state: RootState) => state.system.isProcessing + ); + + const mergeModelsHandler = () => { + let modelsToMerge: string[] = [modelOne, modelTwo, modelThree]; + modelsToMerge = modelsToMerge.filter((model) => model !== 'none'); + + const mergeModelsInfo: InvokeAI.InvokeModelMergingProps = { + models_to_merge: modelsToMerge, + merged_model_name: + mergedModelName !== '' ? mergedModelName : modelsToMerge.join('-'), + alpha: modelMergeAlpha, + interp: modelMergeInterp, + model_merge_save_path: + modelMergeSaveLocType === 'root' ? null : modelMergeCustomSaveLoc, + force: modelMergeForce, + }; + + dispatch(mergeDiffusersModels(mergeModelsInfo)); + }; + + return ( + <> + + + {t('modelmanager:mergeModels')} + + + + + + + {t('modelmanager:mergeModels')} + + + + {t('modelmanager:modelMergeHeaderHelp1')} + + {t('modelmanager:modelMergeHeaderHelp2')} + + + + setModelOne(e.target.value)} + /> + setModelTwo(e.target.value)} + /> + { + if (e.target.value !== 'none') { + setModelThree(e.target.value); + setModelMergeInterp('add_difference'); + } else { + setModelThree('none'); + setModelMergeInterp('weighted_sum'); + } + }} + /> + + + setMergedModelName(e.target.value)} + /> + + + setModelMergeAlpha(v)} + withInput + withReset + handleReset={() => setModelMergeAlpha(0.5)} + withSliderMarks + /> + + {t('modelmanager:modelMergeAlphaHelp')} + + + + + + {t('modelmanager:interpolationType')} + + setModelMergeInterp(v)} + > + + {modelThree === 'none' ? ( + <> + weighted_sum + sigmoid + inv_sigmoid + + ) : ( + + + add_difference + + + )} + + + + + + + + {t('modelmanager:mergedModelSaveLocation')} + + + setModelMergeSaveLocType(v) + } + > + + + {t('modelmanager:invokeAIFolder')} + + {t('modelmanager:custom')} + + + + + {modelMergeSaveLocType === 'custom' && ( + setModelMergeCustomSaveLoc(e.target.value)} + /> + )} + + + setModelMergeForce(e.target.checked)} + fontWeight="bold" + /> + + + {t('modelmanager:merge')} + + + + + + ); +} diff --git a/invokeai/frontend/src/features/system/components/ModelManager/ModelList.tsx b/invokeai/frontend/src/features/system/components/ModelManager/ModelList.tsx index 0e961bd3d8..5471711656 100644 --- a/invokeai/frontend/src/features/system/components/ModelManager/ModelList.tsx +++ b/invokeai/frontend/src/features/system/components/ModelManager/ModelList.tsx @@ -14,6 +14,7 @@ import { systemSelector } from 'features/system/store/systemSelectors'; import type { SystemState } from 'features/system/store/systemSlice'; import { isEqual, map } from 'lodash'; import type { ChangeEvent, ReactNode } from 'react'; +import MergeModels from './MergeModels'; const modelListSelector = createSelector( systemSelector, @@ -181,7 +182,10 @@ const ModelList = () => { {t('modelmanager:availableModels')} - + + + + state.system; @@ -28,3 +28,23 @@ export const activeModelSelector = createSelector( }, } ); + +export const diffusersModelsSelector = createSelector( + systemSelector, + (system) => { + const { model_list } = system; + + const diffusersModels = pickBy(model_list, (model, key) => { + if (model.format === 'diffusers') { + return { name: key, ...model }; + } + }); + + return diffusersModels; + }, + { + memoizeOptions: { + resultEqualityCheck: isEqual, + }, + } +); diff --git a/invokeai/frontend/src/features/system/store/systemSlice.ts b/invokeai/frontend/src/features/system/store/systemSlice.ts index e2f3355b31..e5476df5c8 100644 --- a/invokeai/frontend/src/features/system/store/systemSlice.ts +++ b/invokeai/frontend/src/features/system/store/systemSlice.ts @@ -220,6 +220,12 @@ export const systemSlice = createSlice({ state.isProcessing = true; state.currentStatusHasSteps = false; }, + modelMergingRequested: (state) => { + state.currentStatus = i18n.t('common:statusMergingModels'); + state.isCancelable = false; + state.isProcessing = true; + state.currentStatusHasSteps = false; + }, setSaveIntermediatesInterval: (state, action: PayloadAction) => { state.saveIntermediatesInterval = action.payload; }, @@ -272,6 +278,7 @@ export const { setIsCancelable, modelChangeRequested, modelConvertRequested, + modelMergingRequested, setSaveIntermediatesInterval, setEnableImageDebugging, generationRequested,