mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
make model merge script work with new model manager
This commit is contained in:
@ -345,6 +345,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
if self._event_bus:
|
||||
kwargs.update(event_handlers=[self._event_bus.emit_model_event])
|
||||
self._loader = ModelLoad(config, **kwargs)
|
||||
self._loader.installer.scan_models_directory() # synchronize new/deleted models found in models directory
|
||||
|
||||
def get_model(
|
||||
self,
|
||||
|
@ -23,4 +23,5 @@ from .storage import ( # noqa F401
|
||||
ModelConfigStoreSQL,
|
||||
ModelConfigStoreYAML,
|
||||
UnknownModelException,
|
||||
get_config_store,
|
||||
)
|
||||
|
@ -369,7 +369,8 @@ class ModelInstall(ModelInstallBase):
|
||||
self._tmpdir = None
|
||||
|
||||
# this step synchronizes the `models` directory with the models db
|
||||
self.scan_models_directory()
|
||||
# do NOT do this automatically, but only on app startup
|
||||
# self.scan_models_directory()
|
||||
|
||||
@property
|
||||
def queue(self) -> DownloadQueueBase:
|
||||
|
@ -9,18 +9,32 @@ import curses
|
||||
import sys
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import npyscreen
|
||||
from npyscreen import widget
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_management import BaseModelType, ModelManager, ModelMerger, ModelType
|
||||
from invokeai.backend.model_manager import (
|
||||
BaseModelType,
|
||||
ModelConfigStore,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
get_config_store,
|
||||
)
|
||||
from invokeai.backend.model_manager.merge import ModelMerger
|
||||
from invokeai.frontend.install.widgets import FloatTitleSlider, SingleSelectColumns, TextBox
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
|
||||
BASE_TYPES = [
|
||||
(BaseModelType.StableDiffusion1, "Models Built on SD-1.x"),
|
||||
(BaseModelType.StableDiffusion2, "Models Built on SD-2.x"),
|
||||
(BaseModelType.StableDiffusionXL, "Models Built on SDXL"),
|
||||
]
|
||||
|
||||
|
||||
def _parse_args() -> Namespace:
|
||||
parser = argparse.ArgumentParser(description="InvokeAI model merging")
|
||||
@ -48,7 +62,7 @@ def _parse_args() -> Namespace:
|
||||
parser.add_argument(
|
||||
"--base_model",
|
||||
type=str,
|
||||
choices=[x.value for x in BaseModelType],
|
||||
choices=[x[0].value for x in BASE_TYPES],
|
||||
help="The base model shared by the models to be merged",
|
||||
)
|
||||
parser.add_argument(
|
||||
@ -106,9 +120,9 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
|
||||
def create(self):
|
||||
window_height, window_width = curses.initscr().getmaxyx()
|
||||
|
||||
self.model_names = self.get_model_names()
|
||||
self.current_base = 0
|
||||
self.models = self.get_models(BASE_TYPES[self.current_base][0])
|
||||
self.model_names = [x[1] for x in self.models]
|
||||
max_width = max([len(x) for x in self.model_names])
|
||||
max_width += 6
|
||||
horizontal_layout = max_width * 3 < window_width
|
||||
@ -128,10 +142,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
self.nextrely += 1
|
||||
self.base_select = self.add_widget_intelligent(
|
||||
SingleSelectColumns,
|
||||
values=[
|
||||
"Models Built on SD-1.x",
|
||||
"Models Built on SD-2.x",
|
||||
],
|
||||
values=[x[1] for x in BASE_TYPES],
|
||||
value=[self.current_base],
|
||||
columns=4,
|
||||
max_height=2,
|
||||
@ -262,19 +273,19 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
sys.exit(0)
|
||||
|
||||
def marshall_arguments(self) -> dict:
|
||||
model_names = self.model_names
|
||||
model_keys = [x[0] for x in self.models]
|
||||
models = [
|
||||
model_names[self.model1.value[0]],
|
||||
model_names[self.model2.value[0]],
|
||||
model_keys[self.model1.value[0]],
|
||||
model_keys[self.model2.value[0]],
|
||||
]
|
||||
if self.model3.value[0] > 0:
|
||||
models.append(model_names[self.model3.value[0] - 1])
|
||||
models.append(model_keys[self.model3.value[0] - 1])
|
||||
interp = "add_difference"
|
||||
else:
|
||||
interp = self.interpolations[self.merge_method.value[0]]
|
||||
|
||||
args = dict(
|
||||
model_names=models,
|
||||
model_keys=models,
|
||||
base_model=tuple(BaseModelType)[self.base_select.value[0]],
|
||||
alpha=self.alpha.value,
|
||||
interp=interp,
|
||||
@ -309,17 +320,18 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
else:
|
||||
return True
|
||||
|
||||
def get_model_names(self, base_model: Optional[BaseModelType] = None) -> List[str]:
|
||||
model_names = [
|
||||
info["model_name"]
|
||||
for info in self.model_manager.list_models(model_type=ModelType.Main, base_model=base_model)
|
||||
if info["model_format"] == "diffusers"
|
||||
def get_models(self, base_model: Optional[BaseModelType] = None) -> List[Tuple[str, str]]: # key to name
|
||||
models = [
|
||||
(x.key, x.name)
|
||||
for x in self.model_manager.search_by_name(model_type=ModelType.Main, base_model=base_model)
|
||||
if x.model_format == ModelFormat("diffusers") and x.variant == ModelVariantType("normal")
|
||||
]
|
||||
return sorted(model_names)
|
||||
return sorted(models, key=lambda x: x[1])
|
||||
|
||||
def _populate_models(self, value=None):
|
||||
base_model = tuple(BaseModelType)[value[0]]
|
||||
self.model_names = self.get_model_names(base_model)
|
||||
def _populate_models(self, value: int):
|
||||
base_model = BASE_TYPES[value[0]][0]
|
||||
self.models = self.get_models(base_model)
|
||||
self.model_names = [x[1] for x in self.models]
|
||||
|
||||
models_plus_none = self.model_names.copy()
|
||||
models_plus_none.insert(0, "None")
|
||||
@ -331,7 +343,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
|
||||
|
||||
class Mergeapp(npyscreen.NPSAppManaged):
|
||||
def __init__(self, model_manager: ModelManager):
|
||||
def __init__(self, model_manager: ModelConfigStore):
|
||||
super().__init__()
|
||||
self.model_manager = model_manager
|
||||
|
||||
@ -341,13 +353,12 @@ class Mergeapp(npyscreen.NPSAppManaged):
|
||||
|
||||
|
||||
def run_gui(args: Namespace):
|
||||
model_manager = ModelManager(config.model_conf_path)
|
||||
model_manager: ModelConfigStore = get_config_store(config.model_conf_path)
|
||||
mergeapp = Mergeapp(model_manager)
|
||||
mergeapp.run()
|
||||
|
||||
args = mergeapp.merge_arguments
|
||||
merger = ModelMerger(model_manager)
|
||||
merger.merge_diffusion_models_and_save(**args)
|
||||
merger = ModelMerger(model_manager, config)
|
||||
merger.merge_diffusion_models_and_save(args)
|
||||
logger.info(f'Models merged into new model: "{args["merged_model_name"]}".')
|
||||
|
||||
|
||||
@ -361,13 +372,31 @@ def run_cli(args: Namespace):
|
||||
args.merged_model_name = "+".join(args.model_names)
|
||||
logger.info(f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"')
|
||||
|
||||
model_manager = ModelManager(config.model_conf_path)
|
||||
model_manager: ModelConfigStore = get_config_store(config.model_conf_path)
|
||||
assert (
|
||||
not model_manager.model_exists(args.merged_model_name, args.base_model, ModelType.Main) or args.clobber
|
||||
len(model_manager.search_by_name(args.merged_model_name, args.base_model, ModelType.Main)) == 0 or args.clobber
|
||||
), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
|
||||
|
||||
merger = ModelMerger(model_manager)
|
||||
merger.merge_diffusion_models_and_save(**vars(args))
|
||||
model_keys = []
|
||||
for name in args.model_names:
|
||||
if len(name) == 32 and re.match(r"^[0-9a-f]$", name):
|
||||
model_keys.append(name)
|
||||
else:
|
||||
models = model_manager.search_by_name(
|
||||
model_name=name, model_type=ModelType.Main, base_model=BaseModelType(args.base_model)
|
||||
)
|
||||
assert len(models) > 0, f"{name}: Unknown model"
|
||||
assert len(models) < 2, f"{name}: More than one model by this name. Please specify the model key instead."
|
||||
model_keys.append(models[0].key)
|
||||
|
||||
merger.merge_diffusion_models_and_save(
|
||||
alpha=args.alpha,
|
||||
model_keys=model_keys,
|
||||
merged_model_name=args.merged_model_name,
|
||||
interp=args.interp,
|
||||
force=args.force,
|
||||
)
|
||||
logger.info(f'Models merged into new model: "{args.merged_model_name}".')
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user