Fix issues identified during PR review by RyanjDick and brandonrising

- ModelMetadataStoreService is now injected into ModelRecordStoreService
  (these two services are really joined at the hip, and should someday be merged)
- ModelRecordStoreService is now injected into ModelManagerService
- Reduced timeout value for the various installer and download wait*() methods
- Introduced a Mock modelmanager for testing
- Removed bare print() statement with _logger in the install helper backend.
- Removed unused code from model loader init file
- Made `locker` a private variable in the `LoadedModel` object.
- Fixed up model merge frontend (will be deprecated anyway!)
This commit is contained in:
Lincoln Stein
2024-02-15 22:41:29 -05:00
committed by psychedelicious
parent f1597bd6da
commit 996eb96b4e
22 changed files with 449 additions and 131 deletions

View File

@ -6,20 +6,40 @@ Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
"""
import argparse
import curses
import re
import sys
from argparse import Namespace
from pathlib import Path
from typing import List
from typing import List, Optional, Tuple
import npyscreen
from npyscreen import widget
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_management import BaseModelType, ModelManager, ModelMerger, ModelType
from invokeai.app.services.download import DownloadQueueService
from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage
from invokeai.app.services.model_install import ModelInstallService
from invokeai.app.services.model_metadata import ModelMetadataStoreSQL
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.backend.model_manager import (
BaseModelType,
ModelFormat,
ModelType,
ModelVariantType,
)
from invokeai.backend.model_manager.merge import ModelMerger
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.frontend.install.widgets import FloatTitleSlider, SingleSelectColumns, TextBox
config = InvokeAIAppConfig.get_config()
logger = InvokeAILogger.get_logger()
BASE_TYPES = [
(BaseModelType.StableDiffusion1, "Models Built on SD-1.x"),
(BaseModelType.StableDiffusion2, "Models Built on SD-2.x"),
(BaseModelType.StableDiffusionXL, "Models Built on SDXL"),
]
def _parse_args() -> Namespace:
@ -48,7 +68,7 @@ def _parse_args() -> Namespace:
parser.add_argument(
"--base_model",
type=str,
choices=[x.value for x in BaseModelType],
choices=[x[0].value for x in BASE_TYPES],
help="The base model shared by the models to be merged",
)
parser.add_argument(
@ -98,17 +118,17 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
super().__init__(parentApp, name)
@property
def model_manager(self):
return self.parentApp.model_manager
def record_store(self):
return self.parentApp.record_store
def afterEditing(self):
self.parentApp.setNextForm(None)
def create(self):
window_height, window_width = curses.initscr().getmaxyx()
self.model_names = self.get_model_names()
self.current_base = 0
self.models = self.get_models(BASE_TYPES[self.current_base][0])
self.model_names = [x[1] for x in self.models]
max_width = max([len(x) for x in self.model_names])
max_width += 6
horizontal_layout = max_width * 3 < window_width
@ -128,11 +148,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
self.nextrely += 1
self.base_select = self.add_widget_intelligent(
SingleSelectColumns,
values=[
"Models Built on SD-1.x",
"Models Built on SD-2.x",
"Models Built on SDXL",
],
values=[x[1] for x in BASE_TYPES],
value=[self.current_base],
columns=4,
max_height=2,
@ -263,21 +279,20 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
sys.exit(0)
def marshall_arguments(self) -> dict:
model_names = self.model_names
model_keys = [x[0] for x in self.models]
models = [
model_names[self.model1.value[0]],
model_names[self.model2.value[0]],
model_keys[self.model1.value[0]],
model_keys[self.model2.value[0]],
]
if self.model3.value[0] > 0:
models.append(model_names[self.model3.value[0] - 1])
models.append(model_keys[self.model3.value[0] - 1])
interp = "add_difference"
else:
interp = self.interpolations[self.merge_method.value[0]]
bases = ["sd-1", "sd-2", "sdxl"]
args = {
"model_names": models,
"base_model": BaseModelType(bases[self.base_select.value[0]]),
"model_keys": models,
"base_model": tuple(BaseModelType)[self.base_select.value[0]],
"alpha": self.alpha.value,
"interp": interp,
"force": self.force.value,
@ -311,18 +326,18 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
else:
return True
def get_model_names(self, base_model: BaseModelType = BaseModelType.StableDiffusion1) -> List[str]:
model_names = [
info["model_name"]
for info in self.model_manager.list_models(model_type=ModelType.Main, base_model=base_model)
if info["model_format"] == "diffusers"
def get_models(self, base_model: Optional[BaseModelType] = None) -> List[Tuple[str, str]]: # key to name
models = [
(x.key, x.name)
for x in self.record_store.search_by_attr(model_type=ModelType.Main, base_model=base_model)
if x.format == ModelFormat("diffusers") and x.variant == ModelVariantType("normal")
]
return sorted(model_names)
return sorted(models, key=lambda x: x[1])
def _populate_models(self, value=None):
bases = ["sd-1", "sd-2", "sdxl"]
base_model = BaseModelType(bases[value[0]])
self.model_names = self.get_model_names(base_model)
def _populate_models(self, value: List[int]):
base_model = BASE_TYPES[value[0]][0]
self.models = self.get_models(base_model)
self.model_names = [x[1] for x in self.models]
models_plus_none = self.model_names.copy()
models_plus_none.insert(0, "None")
@ -334,24 +349,24 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
class Mergeapp(npyscreen.NPSAppManaged):
def __init__(self, model_manager: ModelManager):
def __init__(self, record_store: ModelRecordServiceBase):
super().__init__()
self.model_manager = model_manager
self.record_store = record_store
def onStart(self):
npyscreen.setTheme(npyscreen.Themes.ElegantTheme)
self.main = self.addForm("MAIN", mergeModelsForm, name="Merge Models Settings")
def run_gui(args: Namespace):
model_manager = ModelManager(config.model_conf_path)
mergeapp = Mergeapp(model_manager)
def run_gui(args: Namespace) -> None:
record_store: ModelRecordServiceBase = get_config_store()
mergeapp = Mergeapp(record_store)
mergeapp.run()
args = mergeapp.merge_arguments
merger = ModelMerger(model_manager)
merger = get_model_merger(record_store)
merger.merge_diffusion_models_and_save(**args)
logger.info(f'Models merged into new model: "{args["merged_model_name"]}".')
merged_model_name = args["merged_model_name"]
logger.info(f'Models merged into new model: "{merged_model_name}".')
def run_cli(args: Namespace):
@ -364,20 +379,54 @@ def run_cli(args: Namespace):
args.merged_model_name = "+".join(args.model_names)
logger.info(f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"')
model_manager = ModelManager(config.model_conf_path)
record_store: ModelRecordServiceBase = get_config_store()
assert (
not model_manager.model_exists(args.merged_model_name, args.base_model, ModelType.Main) or args.clobber
len(record_store.search_by_attr(args.merged_model_name, args.base_model, ModelType.Main)) == 0 or args.clobber
), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
merger = ModelMerger(model_manager)
merger.merge_diffusion_models_and_save(**vars(args))
merger = get_model_merger(record_store)
model_keys = []
for name in args.model_names:
if len(name) == 32 and re.match(r"^[0-9a-f]$", name):
model_keys.append(name)
else:
models = record_store.search_by_attr(
model_name=name, model_type=ModelType.Main, base_model=BaseModelType(args.base_model)
)
assert len(models) > 0, f"{name}: Unknown model"
assert len(models) < 2, f"{name}: More than one model by this name. Please specify the model key instead."
model_keys.append(models[0].key)
merger.merge_diffusion_models_and_save(
alpha=args.alpha,
model_keys=model_keys,
merged_model_name=args.merged_model_name,
interp=args.interp,
force=args.force,
)
logger.info(f'Models merged into new model: "{args.merged_model_name}".')
def get_config_store() -> ModelRecordServiceSQL:
output_path = config.output_path
assert output_path is not None
image_files = DiskImageFileStorage(output_path / "images")
db = init_db(config=config, logger=InvokeAILogger.get_logger(), image_files=image_files)
return ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
def get_model_merger(record_store: ModelRecordServiceBase) -> ModelMerger:
installer = ModelInstallService(app_config=config, record_store=record_store, download_queue=DownloadQueueService())
installer.start()
return ModelMerger(installer)
def main():
args = _parse_args()
if args.root_dir:
config.parse_args(["--root", str(args.root_dir)])
else:
config.parse_args([])
try:
if args.front_end: