mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fix issues identified during PR review by RyanjDick and brandonrising
- ModelMetadataStoreService is now injected into ModelRecordStoreService (these two services are really joined at the hip, and should someday be merged) - ModelRecordStoreService is now injected into ModelManagerService - Reduced timeout value for the various installer and download wait*() methods - Introduced a Mock modelmanager for testing - Removed bare print() statement with _logger in the install helper backend. - Removed unused code from model loader init file - Made `locker` a private variable in the `LoadedModel` object. - Fixed up model merge frontend (will be deprecated anyway!)
This commit is contained in:
committed by
psychedelicious
parent
f1597bd6da
commit
996eb96b4e
@ -6,20 +6,40 @@ Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
|
||||
"""
|
||||
import argparse
|
||||
import curses
|
||||
import re
|
||||
import sys
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import npyscreen
|
||||
from npyscreen import widget
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_management import BaseModelType, ModelManager, ModelMerger, ModelType
|
||||
from invokeai.app.services.download import DownloadQueueService
|
||||
from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage
|
||||
from invokeai.app.services.model_install import ModelInstallService
|
||||
from invokeai.app.services.model_metadata import ModelMetadataStoreSQL
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
|
||||
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
||||
from invokeai.backend.model_manager import (
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
)
|
||||
from invokeai.backend.model_manager.merge import ModelMerger
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.frontend.install.widgets import FloatTitleSlider, SingleSelectColumns, TextBox
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
logger = InvokeAILogger.get_logger()
|
||||
|
||||
BASE_TYPES = [
|
||||
(BaseModelType.StableDiffusion1, "Models Built on SD-1.x"),
|
||||
(BaseModelType.StableDiffusion2, "Models Built on SD-2.x"),
|
||||
(BaseModelType.StableDiffusionXL, "Models Built on SDXL"),
|
||||
]
|
||||
|
||||
|
||||
def _parse_args() -> Namespace:
|
||||
@ -48,7 +68,7 @@ def _parse_args() -> Namespace:
|
||||
parser.add_argument(
|
||||
"--base_model",
|
||||
type=str,
|
||||
choices=[x.value for x in BaseModelType],
|
||||
choices=[x[0].value for x in BASE_TYPES],
|
||||
help="The base model shared by the models to be merged",
|
||||
)
|
||||
parser.add_argument(
|
||||
@ -98,17 +118,17 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
super().__init__(parentApp, name)
|
||||
|
||||
@property
|
||||
def model_manager(self):
|
||||
return self.parentApp.model_manager
|
||||
def record_store(self):
|
||||
return self.parentApp.record_store
|
||||
|
||||
def afterEditing(self):
|
||||
self.parentApp.setNextForm(None)
|
||||
|
||||
def create(self):
|
||||
window_height, window_width = curses.initscr().getmaxyx()
|
||||
|
||||
self.model_names = self.get_model_names()
|
||||
self.current_base = 0
|
||||
self.models = self.get_models(BASE_TYPES[self.current_base][0])
|
||||
self.model_names = [x[1] for x in self.models]
|
||||
max_width = max([len(x) for x in self.model_names])
|
||||
max_width += 6
|
||||
horizontal_layout = max_width * 3 < window_width
|
||||
@ -128,11 +148,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
self.nextrely += 1
|
||||
self.base_select = self.add_widget_intelligent(
|
||||
SingleSelectColumns,
|
||||
values=[
|
||||
"Models Built on SD-1.x",
|
||||
"Models Built on SD-2.x",
|
||||
"Models Built on SDXL",
|
||||
],
|
||||
values=[x[1] for x in BASE_TYPES],
|
||||
value=[self.current_base],
|
||||
columns=4,
|
||||
max_height=2,
|
||||
@ -263,21 +279,20 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
sys.exit(0)
|
||||
|
||||
def marshall_arguments(self) -> dict:
|
||||
model_names = self.model_names
|
||||
model_keys = [x[0] for x in self.models]
|
||||
models = [
|
||||
model_names[self.model1.value[0]],
|
||||
model_names[self.model2.value[0]],
|
||||
model_keys[self.model1.value[0]],
|
||||
model_keys[self.model2.value[0]],
|
||||
]
|
||||
if self.model3.value[0] > 0:
|
||||
models.append(model_names[self.model3.value[0] - 1])
|
||||
models.append(model_keys[self.model3.value[0] - 1])
|
||||
interp = "add_difference"
|
||||
else:
|
||||
interp = self.interpolations[self.merge_method.value[0]]
|
||||
|
||||
bases = ["sd-1", "sd-2", "sdxl"]
|
||||
args = {
|
||||
"model_names": models,
|
||||
"base_model": BaseModelType(bases[self.base_select.value[0]]),
|
||||
"model_keys": models,
|
||||
"base_model": tuple(BaseModelType)[self.base_select.value[0]],
|
||||
"alpha": self.alpha.value,
|
||||
"interp": interp,
|
||||
"force": self.force.value,
|
||||
@ -311,18 +326,18 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
else:
|
||||
return True
|
||||
|
||||
def get_model_names(self, base_model: BaseModelType = BaseModelType.StableDiffusion1) -> List[str]:
|
||||
model_names = [
|
||||
info["model_name"]
|
||||
for info in self.model_manager.list_models(model_type=ModelType.Main, base_model=base_model)
|
||||
if info["model_format"] == "diffusers"
|
||||
def get_models(self, base_model: Optional[BaseModelType] = None) -> List[Tuple[str, str]]: # key to name
|
||||
models = [
|
||||
(x.key, x.name)
|
||||
for x in self.record_store.search_by_attr(model_type=ModelType.Main, base_model=base_model)
|
||||
if x.format == ModelFormat("diffusers") and x.variant == ModelVariantType("normal")
|
||||
]
|
||||
return sorted(model_names)
|
||||
return sorted(models, key=lambda x: x[1])
|
||||
|
||||
def _populate_models(self, value=None):
|
||||
bases = ["sd-1", "sd-2", "sdxl"]
|
||||
base_model = BaseModelType(bases[value[0]])
|
||||
self.model_names = self.get_model_names(base_model)
|
||||
def _populate_models(self, value: List[int]):
|
||||
base_model = BASE_TYPES[value[0]][0]
|
||||
self.models = self.get_models(base_model)
|
||||
self.model_names = [x[1] for x in self.models]
|
||||
|
||||
models_plus_none = self.model_names.copy()
|
||||
models_plus_none.insert(0, "None")
|
||||
@ -334,24 +349,24 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
|
||||
|
||||
class Mergeapp(npyscreen.NPSAppManaged):
|
||||
def __init__(self, model_manager: ModelManager):
|
||||
def __init__(self, record_store: ModelRecordServiceBase):
|
||||
super().__init__()
|
||||
self.model_manager = model_manager
|
||||
self.record_store = record_store
|
||||
|
||||
def onStart(self):
|
||||
npyscreen.setTheme(npyscreen.Themes.ElegantTheme)
|
||||
self.main = self.addForm("MAIN", mergeModelsForm, name="Merge Models Settings")
|
||||
|
||||
|
||||
def run_gui(args: Namespace):
|
||||
model_manager = ModelManager(config.model_conf_path)
|
||||
mergeapp = Mergeapp(model_manager)
|
||||
def run_gui(args: Namespace) -> None:
|
||||
record_store: ModelRecordServiceBase = get_config_store()
|
||||
mergeapp = Mergeapp(record_store)
|
||||
mergeapp.run()
|
||||
|
||||
args = mergeapp.merge_arguments
|
||||
merger = ModelMerger(model_manager)
|
||||
merger = get_model_merger(record_store)
|
||||
merger.merge_diffusion_models_and_save(**args)
|
||||
logger.info(f'Models merged into new model: "{args["merged_model_name"]}".')
|
||||
merged_model_name = args["merged_model_name"]
|
||||
logger.info(f'Models merged into new model: "{merged_model_name}".')
|
||||
|
||||
|
||||
def run_cli(args: Namespace):
|
||||
@ -364,20 +379,54 @@ def run_cli(args: Namespace):
|
||||
args.merged_model_name = "+".join(args.model_names)
|
||||
logger.info(f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"')
|
||||
|
||||
model_manager = ModelManager(config.model_conf_path)
|
||||
record_store: ModelRecordServiceBase = get_config_store()
|
||||
assert (
|
||||
not model_manager.model_exists(args.merged_model_name, args.base_model, ModelType.Main) or args.clobber
|
||||
len(record_store.search_by_attr(args.merged_model_name, args.base_model, ModelType.Main)) == 0 or args.clobber
|
||||
), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
|
||||
|
||||
merger = ModelMerger(model_manager)
|
||||
merger.merge_diffusion_models_and_save(**vars(args))
|
||||
merger = get_model_merger(record_store)
|
||||
model_keys = []
|
||||
for name in args.model_names:
|
||||
if len(name) == 32 and re.match(r"^[0-9a-f]$", name):
|
||||
model_keys.append(name)
|
||||
else:
|
||||
models = record_store.search_by_attr(
|
||||
model_name=name, model_type=ModelType.Main, base_model=BaseModelType(args.base_model)
|
||||
)
|
||||
assert len(models) > 0, f"{name}: Unknown model"
|
||||
assert len(models) < 2, f"{name}: More than one model by this name. Please specify the model key instead."
|
||||
model_keys.append(models[0].key)
|
||||
|
||||
merger.merge_diffusion_models_and_save(
|
||||
alpha=args.alpha,
|
||||
model_keys=model_keys,
|
||||
merged_model_name=args.merged_model_name,
|
||||
interp=args.interp,
|
||||
force=args.force,
|
||||
)
|
||||
logger.info(f'Models merged into new model: "{args.merged_model_name}".')
|
||||
|
||||
|
||||
def get_config_store() -> ModelRecordServiceSQL:
|
||||
output_path = config.output_path
|
||||
assert output_path is not None
|
||||
image_files = DiskImageFileStorage(output_path / "images")
|
||||
db = init_db(config=config, logger=InvokeAILogger.get_logger(), image_files=image_files)
|
||||
return ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
|
||||
|
||||
|
||||
def get_model_merger(record_store: ModelRecordServiceBase) -> ModelMerger:
|
||||
installer = ModelInstallService(app_config=config, record_store=record_store, download_queue=DownloadQueueService())
|
||||
installer.start()
|
||||
return ModelMerger(installer)
|
||||
|
||||
|
||||
def main():
|
||||
args = _parse_args()
|
||||
if args.root_dir:
|
||||
config.parse_args(["--root", str(args.root_dir)])
|
||||
else:
|
||||
config.parse_args([])
|
||||
|
||||
try:
|
||||
if args.front_end:
|
||||
|
Reference in New Issue
Block a user