make model merge script work with new model manager

This commit is contained in:
Lincoln Stein
2023-09-30 12:24:39 -04:00
parent 807ae821ea
commit acaaff4b7e
4 changed files with 64 additions and 32 deletions

View File

@ -345,6 +345,7 @@ class ModelManagerService(ModelManagerServiceBase):
if self._event_bus:
kwargs.update(event_handlers=[self._event_bus.emit_model_event])
self._loader = ModelLoad(config, **kwargs)
self._loader.installer.scan_models_directory() # synchronize new/deleted models found in models directory
def get_model(
self,

View File

@ -23,4 +23,5 @@ from .storage import ( # noqa F401
ModelConfigStoreSQL,
ModelConfigStoreYAML,
UnknownModelException,
get_config_store,
)

View File

@ -369,7 +369,8 @@ class ModelInstall(ModelInstallBase):
self._tmpdir = None
# this step synchronizes the `models` directory with the models db
self.scan_models_directory()
# do NOT do this automatically, but only on app startup
# self.scan_models_directory()
@property
def queue(self) -> DownloadQueueBase:

View File

@ -9,18 +9,32 @@ import curses
import sys
from argparse import Namespace
from pathlib import Path
from typing import List, Optional
from typing import Dict, List, Optional, Tuple
import npyscreen
from npyscreen import widget
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_management import BaseModelType, ModelManager, ModelMerger, ModelType
from invokeai.backend.model_manager import (
BaseModelType,
ModelConfigStore,
ModelFormat,
ModelType,
ModelVariantType,
get_config_store,
)
from invokeai.backend.model_manager.merge import ModelMerger
from invokeai.frontend.install.widgets import FloatTitleSlider, SingleSelectColumns, TextBox
config = InvokeAIAppConfig.get_config()
BASE_TYPES = [
(BaseModelType.StableDiffusion1, "Models Built on SD-1.x"),
(BaseModelType.StableDiffusion2, "Models Built on SD-2.x"),
(BaseModelType.StableDiffusionXL, "Models Built on SDXL"),
]
def _parse_args() -> Namespace:
parser = argparse.ArgumentParser(description="InvokeAI model merging")
@ -48,7 +62,7 @@ def _parse_args() -> Namespace:
parser.add_argument(
"--base_model",
type=str,
choices=[x.value for x in BaseModelType],
choices=[x[0].value for x in BASE_TYPES],
help="The base model shared by the models to be merged",
)
parser.add_argument(
@ -106,9 +120,9 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
def create(self):
window_height, window_width = curses.initscr().getmaxyx()
self.model_names = self.get_model_names()
self.current_base = 0
self.models = self.get_models(BASE_TYPES[self.current_base][0])
self.model_names = [x[1] for x in self.models]
max_width = max([len(x) for x in self.model_names])
max_width += 6
horizontal_layout = max_width * 3 < window_width
@ -128,10 +142,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
self.nextrely += 1
self.base_select = self.add_widget_intelligent(
SingleSelectColumns,
values=[
"Models Built on SD-1.x",
"Models Built on SD-2.x",
],
values=[x[1] for x in BASE_TYPES],
value=[self.current_base],
columns=4,
max_height=2,
@ -262,19 +273,19 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
sys.exit(0)
def marshall_arguments(self) -> dict:
model_names = self.model_names
model_keys = [x[0] for x in self.models]
models = [
model_names[self.model1.value[0]],
model_names[self.model2.value[0]],
model_keys[self.model1.value[0]],
model_keys[self.model2.value[0]],
]
if self.model3.value[0] > 0:
models.append(model_names[self.model3.value[0] - 1])
models.append(model_keys[self.model3.value[0] - 1])
interp = "add_difference"
else:
interp = self.interpolations[self.merge_method.value[0]]
args = dict(
model_names=models,
model_keys=models,
base_model=tuple(BaseModelType)[self.base_select.value[0]],
alpha=self.alpha.value,
interp=interp,
@ -309,17 +320,18 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
else:
return True
def get_model_names(self, base_model: Optional[BaseModelType] = None) -> List[str]:
model_names = [
info["model_name"]
for info in self.model_manager.list_models(model_type=ModelType.Main, base_model=base_model)
if info["model_format"] == "diffusers"
def get_models(self, base_model: Optional[BaseModelType] = None) -> List[Tuple[str, str]]: # key to name
models = [
(x.key, x.name)
for x in self.model_manager.search_by_name(model_type=ModelType.Main, base_model=base_model)
if x.model_format == ModelFormat("diffusers") and x.variant == ModelVariantType("normal")
]
return sorted(model_names)
return sorted(models, key=lambda x: x[1])
def _populate_models(self, value=None):
base_model = tuple(BaseModelType)[value[0]]
self.model_names = self.get_model_names(base_model)
def _populate_models(self, value: int):
base_model = BASE_TYPES[value[0]][0]
self.models = self.get_models(base_model)
self.model_names = [x[1] for x in self.models]
models_plus_none = self.model_names.copy()
models_plus_none.insert(0, "None")
@ -331,7 +343,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
class Mergeapp(npyscreen.NPSAppManaged):
def __init__(self, model_manager: ModelManager):
def __init__(self, model_manager: ModelConfigStore):
super().__init__()
self.model_manager = model_manager
@ -341,13 +353,12 @@ class Mergeapp(npyscreen.NPSAppManaged):
def run_gui(args: Namespace):
model_manager = ModelManager(config.model_conf_path)
model_manager: ModelConfigStore = get_config_store(config.model_conf_path)
mergeapp = Mergeapp(model_manager)
mergeapp.run()
args = mergeapp.merge_arguments
merger = ModelMerger(model_manager)
merger.merge_diffusion_models_and_save(**args)
merger = ModelMerger(model_manager, config)
merger.merge_diffusion_models_and_save(args)
logger.info(f'Models merged into new model: "{args["merged_model_name"]}".')
@ -361,13 +372,31 @@ def run_cli(args: Namespace):
args.merged_model_name = "+".join(args.model_names)
logger.info(f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"')
model_manager = ModelManager(config.model_conf_path)
model_manager: ModelConfigStore = get_config_store(config.model_conf_path)
assert (
not model_manager.model_exists(args.merged_model_name, args.base_model, ModelType.Main) or args.clobber
len(model_manager.search_by_name(args.merged_model_name, args.base_model, ModelType.Main)) == 0 or args.clobber
), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
merger = ModelMerger(model_manager)
merger.merge_diffusion_models_and_save(**vars(args))
model_keys = []
for name in args.model_names:
if len(name) == 32 and re.match(r"^[0-9a-f]$", name):
model_keys.append(name)
else:
models = model_manager.search_by_name(
model_name=name, model_type=ModelType.Main, base_model=BaseModelType(args.base_model)
)
assert len(models) > 0, f"{name}: Unknown model"
assert len(models) < 2, f"{name}: More than one model by this name. Please specify the model key instead."
model_keys.append(models[0].key)
merger.merge_diffusion_models_and_save(
alpha=args.alpha,
model_keys=model_keys,
merged_model_name=args.merged_model_name,
interp=args.interp,
force=args.force,
)
logger.info(f'Models merged into new model: "{args.merged_model_name}".')