From acaaff4b7e8ceb518791d9e82579bc5a097442f9 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sat, 30 Sep 2023 12:24:39 -0400 Subject: [PATCH] make model merge script work with new model manager --- .../app/services/model_manager_service.py | 1 + invokeai/backend/model_manager/__init__.py | 1 + invokeai/backend/model_manager/install.py | 3 +- invokeai/frontend/merge/merge_diffusers.py | 91 ++++++++++++------- 4 files changed, 64 insertions(+), 32 deletions(-) diff --git a/invokeai/app/services/model_manager_service.py b/invokeai/app/services/model_manager_service.py index 64289d93ca..b35a076f3d 100644 --- a/invokeai/app/services/model_manager_service.py +++ b/invokeai/app/services/model_manager_service.py @@ -345,6 +345,7 @@ class ModelManagerService(ModelManagerServiceBase): if self._event_bus: kwargs.update(event_handlers=[self._event_bus.emit_model_event]) self._loader = ModelLoad(config, **kwargs) + self._loader.installer.scan_models_directory() # synchronize new/deleted models found in models directory def get_model( self, diff --git a/invokeai/backend/model_manager/__init__.py b/invokeai/backend/model_manager/__init__.py index e96ac1e668..8a51d75a5e 100644 --- a/invokeai/backend/model_manager/__init__.py +++ b/invokeai/backend/model_manager/__init__.py @@ -23,4 +23,5 @@ from .storage import ( # noqa F401 ModelConfigStoreSQL, ModelConfigStoreYAML, UnknownModelException, + get_config_store, ) diff --git a/invokeai/backend/model_manager/install.py b/invokeai/backend/model_manager/install.py index 8d425043f4..040eb2ceac 100644 --- a/invokeai/backend/model_manager/install.py +++ b/invokeai/backend/model_manager/install.py @@ -369,7 +369,8 @@ class ModelInstall(ModelInstallBase): self._tmpdir = None # this step synchronizes the `models` directory with the models db - self.scan_models_directory() + # do NOT do this automatically, but only on app startup + # self.scan_models_directory() @property def queue(self) -> DownloadQueueBase: diff --git a/invokeai/frontend/merge/merge_diffusers.py b/invokeai/frontend/merge/merge_diffusers.py index 8fa02cb49c..440f13de0b 100644 --- a/invokeai/frontend/merge/merge_diffusers.py +++ b/invokeai/frontend/merge/merge_diffusers.py @@ -9,18 +9,32 @@ import curses import sys from argparse import Namespace from pathlib import Path -from typing import List, Optional +from typing import Dict, List, Optional, Tuple import npyscreen from npyscreen import widget import invokeai.backend.util.logging as logger from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.model_management import BaseModelType, ModelManager, ModelMerger, ModelType +from invokeai.backend.model_manager import ( + BaseModelType, + ModelConfigStore, + ModelFormat, + ModelType, + ModelVariantType, + get_config_store, +) +from invokeai.backend.model_manager.merge import ModelMerger from invokeai.frontend.install.widgets import FloatTitleSlider, SingleSelectColumns, TextBox config = InvokeAIAppConfig.get_config() +BASE_TYPES = [ + (BaseModelType.StableDiffusion1, "Models Built on SD-1.x"), + (BaseModelType.StableDiffusion2, "Models Built on SD-2.x"), + (BaseModelType.StableDiffusionXL, "Models Built on SDXL"), +] + def _parse_args() -> Namespace: parser = argparse.ArgumentParser(description="InvokeAI model merging") @@ -48,7 +62,7 @@ def _parse_args() -> Namespace: parser.add_argument( "--base_model", type=str, - choices=[x.value for x in BaseModelType], + choices=[x[0].value for x in BASE_TYPES], help="The base model shared by the models to be merged", ) parser.add_argument( @@ -106,9 +120,9 @@ class mergeModelsForm(npyscreen.FormMultiPageAction): def create(self): window_height, window_width = curses.initscr().getmaxyx() - - self.model_names = self.get_model_names() self.current_base = 0 + self.models = self.get_models(BASE_TYPES[self.current_base][0]) + self.model_names = [x[1] for x in self.models] max_width = max([len(x) for x in self.model_names]) max_width += 6 horizontal_layout = max_width * 3 < window_width @@ -128,10 +142,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction): self.nextrely += 1 self.base_select = self.add_widget_intelligent( SingleSelectColumns, - values=[ - "Models Built on SD-1.x", - "Models Built on SD-2.x", - ], + values=[x[1] for x in BASE_TYPES], value=[self.current_base], columns=4, max_height=2, @@ -262,19 +273,19 @@ class mergeModelsForm(npyscreen.FormMultiPageAction): sys.exit(0) def marshall_arguments(self) -> dict: - model_names = self.model_names + model_keys = [x[0] for x in self.models] models = [ - model_names[self.model1.value[0]], - model_names[self.model2.value[0]], + model_keys[self.model1.value[0]], + model_keys[self.model2.value[0]], ] if self.model3.value[0] > 0: - models.append(model_names[self.model3.value[0] - 1]) + models.append(model_keys[self.model3.value[0] - 1]) interp = "add_difference" else: interp = self.interpolations[self.merge_method.value[0]] args = dict( - model_names=models, + model_keys=models, base_model=tuple(BaseModelType)[self.base_select.value[0]], alpha=self.alpha.value, interp=interp, @@ -309,17 +320,18 @@ class mergeModelsForm(npyscreen.FormMultiPageAction): else: return True - def get_model_names(self, base_model: Optional[BaseModelType] = None) -> List[str]: - model_names = [ - info["model_name"] - for info in self.model_manager.list_models(model_type=ModelType.Main, base_model=base_model) - if info["model_format"] == "diffusers" + def get_models(self, base_model: Optional[BaseModelType] = None) -> List[Tuple[str, str]]: # key to name + models = [ + (x.key, x.name) + for x in self.model_manager.search_by_name(model_type=ModelType.Main, base_model=base_model) + if x.model_format == ModelFormat("diffusers") and x.variant == ModelVariantType("normal") ] - return sorted(model_names) + return sorted(models, key=lambda x: x[1]) - def _populate_models(self, value=None): - base_model = tuple(BaseModelType)[value[0]] - self.model_names = self.get_model_names(base_model) + def _populate_models(self, value: int): + base_model = BASE_TYPES[value[0]][0] + self.models = self.get_models(base_model) + self.model_names = [x[1] for x in self.models] models_plus_none = self.model_names.copy() models_plus_none.insert(0, "None") @@ -331,7 +343,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction): class Mergeapp(npyscreen.NPSAppManaged): - def __init__(self, model_manager: ModelManager): + def __init__(self, model_manager: ModelConfigStore): super().__init__() self.model_manager = model_manager @@ -341,13 +353,12 @@ class Mergeapp(npyscreen.NPSAppManaged): def run_gui(args: Namespace): - model_manager = ModelManager(config.model_conf_path) + model_manager: ModelConfigStore = get_config_store(config.model_conf_path) mergeapp = Mergeapp(model_manager) mergeapp.run() - args = mergeapp.merge_arguments - merger = ModelMerger(model_manager) - merger.merge_diffusion_models_and_save(**args) + merger = ModelMerger(model_manager, config) + merger.merge_diffusion_models_and_save(args) logger.info(f'Models merged into new model: "{args["merged_model_name"]}".') @@ -361,13 +372,31 @@ def run_cli(args: Namespace): args.merged_model_name = "+".join(args.model_names) logger.info(f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"') - model_manager = ModelManager(config.model_conf_path) + model_manager: ModelConfigStore = get_config_store(config.model_conf_path) assert ( - not model_manager.model_exists(args.merged_model_name, args.base_model, ModelType.Main) or args.clobber + len(model_manager.search_by_name(args.merged_model_name, args.base_model, ModelType.Main)) == 0 or args.clobber ), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.' merger = ModelMerger(model_manager) - merger.merge_diffusion_models_and_save(**vars(args)) + model_keys = [] + for name in args.model_names: + if len(name) == 32 and re.match(r"^[0-9a-f]$", name): + model_keys.append(name) + else: + models = model_manager.search_by_name( + model_name=name, model_type=ModelType.Main, base_model=BaseModelType(args.base_model) + ) + assert len(models) > 0, f"{name}: Unknown model" + assert len(models) < 2, f"{name}: More than one model by this name. Please specify the model key instead." + model_keys.append(models[0].key) + + merger.merge_diffusion_models_and_save( + alpha=args.alpha, + model_keys=model_keys, + merged_model_name=args.merged_model_name, + interp=args.interp, + force=args.force, + ) logger.info(f'Models merged into new model: "{args.merged_model_name}".')