mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
model merge backend, CLI and TUI working
This commit is contained in:
parent
f7daa6e71d
commit
ec7c2f07c6
@ -4,5 +4,5 @@ Initialization file for invokeai.backend.model_management
|
|||||||
from .model_manager import ModelManager, ModelInfo, AddModelResult
|
from .model_manager import ModelManager, ModelInfo, AddModelResult
|
||||||
from .model_cache import ModelCache
|
from .model_cache import ModelCache
|
||||||
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType
|
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType
|
||||||
from .model_merge import merge_diffusion_models_and_save, MergeInterpolationMethod
|
from .model_merge import ModelMerger, MergeInterpolationMethod
|
||||||
|
|
||||||
|
@ -279,7 +279,7 @@ class InvalidModelError(Exception):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
class AddModelResult(BaseModel):
|
class AddModelResult(BaseModel):
|
||||||
name: str = Field(description="The name of the model after import")
|
name: str = Field(description="The name of the model after installation")
|
||||||
model_type: ModelType = Field(description="The type of model")
|
model_type: ModelType = Field(description="The type of model")
|
||||||
base_model: BaseModelType = Field(description="The base model")
|
base_model: BaseModelType = Field(description="The base model")
|
||||||
config: ModelConfigBase = Field(description="The configuration of the model")
|
config: ModelConfigBase = Field(description="The configuration of the model")
|
||||||
@ -496,7 +496,7 @@ class ModelManager(object):
|
|||||||
model_name: str,
|
model_name: str,
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
)->dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
Returns a dict describing one installed model, using
|
Returns a dict describing one installed model, using
|
||||||
the combined format of the list_models() method.
|
the combined format of the list_models() method.
|
||||||
|
@ -11,109 +11,119 @@ from enum import Enum
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from diffusers import DiffusionPipeline
|
from diffusers import DiffusionPipeline
|
||||||
from diffusers import logging as dlogging
|
from diffusers import logging as dlogging
|
||||||
from typing import List
|
from typing import List, Union
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from ...backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType
|
from ...backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult
|
||||||
|
|
||||||
class MergeInterpolationMethod(str, Enum):
|
class MergeInterpolationMethod(str, Enum):
|
||||||
Sigmoid = "sigmoid"
|
Sigmoid = "sigmoid"
|
||||||
InvSigmoid = "inv_sigmoid"
|
InvSigmoid = "inv_sigmoid"
|
||||||
AddDifference = "add_difference"
|
AddDifference = "add_difference"
|
||||||
|
WeightedSum = "weighted_sum"
|
||||||
|
|
||||||
def merge_diffusion_models(
|
class ModelMerger(object):
|
||||||
model_paths: List[Path],
|
def __init__(self, manager: ModelManager):
|
||||||
alpha: float = 0.5,
|
self.manager = manager
|
||||||
interp: MergeInterpolationMethod = None,
|
|
||||||
force: bool = False,
|
|
||||||
**kwargs,
|
|
||||||
) -> DiffusionPipeline:
|
|
||||||
"""
|
|
||||||
:param model_paths: up to three models, designated by their local paths or HuggingFace repo_ids
|
|
||||||
:param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
|
|
||||||
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
|
|
||||||
:param interp: The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None.
|
|
||||||
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported.
|
|
||||||
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
|
|
||||||
|
|
||||||
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
def merge_diffusion_models(
|
||||||
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
self,
|
||||||
"""
|
model_paths: List[Path],
|
||||||
with warnings.catch_warnings():
|
alpha: float = 0.5,
|
||||||
warnings.simplefilter("ignore")
|
interp: MergeInterpolationMethod = None,
|
||||||
verbosity = dlogging.get_verbosity()
|
force: bool = False,
|
||||||
dlogging.set_verbosity_error()
|
|
||||||
|
|
||||||
pipe = DiffusionPipeline.from_pretrained(
|
|
||||||
model_paths[0],
|
|
||||||
custom_pipeline="checkpoint_merger",
|
|
||||||
)
|
|
||||||
merged_pipe = pipe.merge(
|
|
||||||
pretrained_model_name_or_path_list=model_paths,
|
|
||||||
alpha=alpha,
|
|
||||||
interp=interp.value if interp else None, #diffusers API treats None as "weighted sum"
|
|
||||||
force=force,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
|
) -> DiffusionPipeline:
|
||||||
|
"""
|
||||||
|
:param model_paths: up to three models, designated by their local paths or HuggingFace repo_ids
|
||||||
|
:param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
|
||||||
|
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
|
||||||
|
:param interp: The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None.
|
||||||
|
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported.
|
||||||
|
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
|
||||||
|
|
||||||
|
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
||||||
|
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
||||||
|
"""
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.simplefilter("ignore")
|
||||||
|
verbosity = dlogging.get_verbosity()
|
||||||
|
dlogging.set_verbosity_error()
|
||||||
|
|
||||||
|
pipe = DiffusionPipeline.from_pretrained(
|
||||||
|
model_paths[0],
|
||||||
|
custom_pipeline="checkpoint_merger",
|
||||||
|
)
|
||||||
|
merged_pipe = pipe.merge(
|
||||||
|
pretrained_model_name_or_path_list=model_paths,
|
||||||
|
alpha=alpha,
|
||||||
|
interp=interp.value if interp else None, #diffusers API treats None as "weighted sum"
|
||||||
|
force=force,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
dlogging.set_verbosity(verbosity)
|
||||||
|
return merged_pipe
|
||||||
|
|
||||||
|
|
||||||
|
def merge_diffusion_models_and_save (
|
||||||
|
self,
|
||||||
|
model_names: List[str],
|
||||||
|
base_model: Union[BaseModelType,str],
|
||||||
|
merged_model_name: str,
|
||||||
|
alpha: float = 0.5,
|
||||||
|
interp: MergeInterpolationMethod = None,
|
||||||
|
force: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
) -> AddModelResult:
|
||||||
|
"""
|
||||||
|
:param models: up to three models, designated by their InvokeAI models.yaml model name
|
||||||
|
:param base_model: base model (must be the same for all merged models!)
|
||||||
|
:param merged_model_name: name for new model
|
||||||
|
:param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
|
||||||
|
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
|
||||||
|
:param interp: The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None.
|
||||||
|
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C).
|
||||||
|
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
|
||||||
|
|
||||||
|
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
||||||
|
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
||||||
|
"""
|
||||||
|
model_paths = list()
|
||||||
|
config = self.manager.app_config
|
||||||
|
base_model = BaseModelType(base_model)
|
||||||
|
vae = None
|
||||||
|
|
||||||
|
for mod in model_names:
|
||||||
|
info = self.manager.list_model(mod, base_model=base_model, model_type=ModelType.Main)
|
||||||
|
assert info, f"model {mod}, base_model {base_model}, is unknown"
|
||||||
|
assert info["model_format"] == "diffusers", f"{mod} is not a diffusers model. It must be optimized before merging"
|
||||||
|
assert info["variant"] == "normal", (f"{mod} is a {info['variant']} model, which cannot currently be merged")
|
||||||
|
# pick up the first model's vae
|
||||||
|
if mod == model_names[0]:
|
||||||
|
vae = info.get("vae")
|
||||||
|
model_paths.extend([config.root_path / info["path"]])
|
||||||
|
|
||||||
|
merge_method = None if interp == 'weighted_sum' else MergeInterpolationMethod(interp)
|
||||||
|
merged_pipe = self.merge_diffusion_models(
|
||||||
|
model_paths, alpha, merge_method, force, **kwargs
|
||||||
)
|
)
|
||||||
dlogging.set_verbosity(verbosity)
|
dump_path = config.models_path / base_model.value / ModelType.Main.value
|
||||||
return merged_pipe
|
dump_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
dump_path = dump_path / merged_model_name
|
||||||
|
|
||||||
|
merged_pipe.save_pretrained(dump_path, safe_serialization=1)
|
||||||
def merge_diffusion_models_and_save (
|
attributes = dict(
|
||||||
models: List["str"],
|
path = str(dump_path),
|
||||||
base_model: BaseModelType,
|
description = f"Merge of models {', '.join(model_names)}",
|
||||||
merged_model_name: str,
|
model_format = "diffusers",
|
||||||
config: InvokeAIAppConfig,
|
variant = ModelVariantType.Normal.value,
|
||||||
alpha: float = 0.5,
|
vae = vae,
|
||||||
interp: MergeInterpolationMethod = None,
|
)
|
||||||
force: bool = False,
|
return self.manager.add_model(merged_model_name,
|
||||||
**kwargs,
|
base_model = base_model,
|
||||||
):
|
model_type = ModelType.Main,
|
||||||
"""
|
model_attributes = attributes,
|
||||||
:param models: up to three models, designated by their InvokeAI models.yaml model name
|
clobber = True
|
||||||
:param base_model: base model (must be the same for all merged models!)
|
)
|
||||||
:param merged_model_name: name for new model
|
|
||||||
:param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
|
|
||||||
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
|
|
||||||
:param interp: The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None.
|
|
||||||
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C).
|
|
||||||
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
|
|
||||||
|
|
||||||
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
|
||||||
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
|
||||||
"""
|
|
||||||
model_manager = ModelManager(config.model_conf_path)
|
|
||||||
model_paths = list()
|
|
||||||
vae = None
|
|
||||||
for mod in models:
|
|
||||||
info = model_manager.model_info(mod, base_model=base_model, model_type=ModelType.main)
|
|
||||||
assert info, f"model {mod}, base_model {base_model}, is unknown"
|
|
||||||
assert info["format"] == "diffusers", f"{mod} is not a diffusers model. It must be optimized before merging"
|
|
||||||
assert info["variant"] == "normal", (f"{mod} is a {info['variant']} model, which cannot currently be merged")
|
|
||||||
if mod == models[0]:
|
|
||||||
vae = info["vae"]
|
|
||||||
model_paths.extend([info["path"]])
|
|
||||||
|
|
||||||
merged_pipe = merge_diffusion_models(
|
|
||||||
model_paths, alpha, interp, force, **kwargs
|
|
||||||
)
|
|
||||||
dump_path = config.models_path / base_model.value / ModelType.main.value
|
|
||||||
dump_path.mkdir(parents=True, exist_ok=True)
|
|
||||||
dump_path = dump_path / merged_model_name
|
|
||||||
|
|
||||||
merged_pipe.save_pretrained(dump_path, safe_serialization=1)
|
|
||||||
attributes = dict(
|
|
||||||
path = dump_path,
|
|
||||||
description = f"Merge of models {', '.join(models)}",
|
|
||||||
model_format = "diffusers",
|
|
||||||
variant = ModelVariantType.Normal.value,
|
|
||||||
vae = vae,
|
|
||||||
)
|
|
||||||
model_manager.add_model(merged_model_name,
|
|
||||||
base_model = base_model,
|
|
||||||
model_type = ModelType.Main,
|
|
||||||
model_attributes = attributes,
|
|
||||||
clobber = True
|
|
||||||
)
|
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
Initialization file for invokeai.frontend.merge
|
Initialization file for invokeai.frontend.merge
|
||||||
"""
|
"""
|
||||||
from .merge_diffusers import main as invokeai_merge_diffusers, merge_diffusion_models
|
from .merge_diffusers import main as invokeai_merge_diffusers
|
||||||
|
|
||||||
|
@ -6,10 +6,7 @@ Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
|
|||||||
"""
|
"""
|
||||||
import argparse
|
import argparse
|
||||||
import curses
|
import curses
|
||||||
import enum
|
|
||||||
import os
|
|
||||||
import sys
|
import sys
|
||||||
import warnings
|
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
@ -21,12 +18,12 @@ from npyscreen import widget
|
|||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from ...backend.model_management import (
|
from invokeai.backend.model_management import (
|
||||||
merge_diffusion_models_and_save,
|
ModelMerger, MergeInterpolationMethod,
|
||||||
ModelManager, MergeInterpolationMethod, BaseModelType
|
ModelManager, ModelType, BaseModelType,
|
||||||
)
|
)
|
||||||
from ...frontend.install.widgets import FloatTitleSlider
|
from invokeai.frontend.install.widgets import FloatTitleSlider, TextBox, SingleSelectColumns
|
||||||
|
|
||||||
config = InvokeAIAppConfig.get_config()
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
|
||||||
@ -48,12 +45,13 @@ def _parse_args() -> Namespace:
|
|||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--models",
|
"--models",
|
||||||
|
dest="model_names",
|
||||||
type=str,
|
type=str,
|
||||||
nargs="+",
|
nargs="+",
|
||||||
help="Two to three model names to be merged",
|
help="Two to three model names to be merged",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--base_type",
|
"--base_model",
|
||||||
type=str,
|
type=str,
|
||||||
choices=[x.value for x in BaseModelType],
|
choices=[x.value for x in BaseModelType],
|
||||||
help="The base model shared by the models to be merged",
|
help="The base model shared by the models to be merged",
|
||||||
@ -115,6 +113,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
|||||||
window_height, window_width = curses.initscr().getmaxyx()
|
window_height, window_width = curses.initscr().getmaxyx()
|
||||||
|
|
||||||
self.model_names = self.get_model_names()
|
self.model_names = self.get_model_names()
|
||||||
|
self.current_base = 0
|
||||||
max_width = max([len(x) for x in self.model_names])
|
max_width = max([len(x) for x in self.model_names])
|
||||||
max_width += 6
|
max_width += 6
|
||||||
horizontal_layout = max_width * 3 < window_width
|
horizontal_layout = max_width * 3 < window_width
|
||||||
@ -131,12 +130,26 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
|||||||
value="Use up and down arrows to move, <space> to select an item, <tab> and <shift-tab> to move from one field to the next.",
|
value="Use up and down arrows to move, <space> to select an item, <tab> and <shift-tab> to move from one field to the next.",
|
||||||
editable=False,
|
editable=False,
|
||||||
)
|
)
|
||||||
|
self.nextrely += 1
|
||||||
|
self.base_select = self.add_widget_intelligent(
|
||||||
|
SingleSelectColumns,
|
||||||
|
values=[
|
||||||
|
'Models Built on SD-1.x',
|
||||||
|
'Models Built on SD-2.x',
|
||||||
|
],
|
||||||
|
value=[self.current_base],
|
||||||
|
columns = 4,
|
||||||
|
max_height = 2,
|
||||||
|
relx=8,
|
||||||
|
scroll_exit = True,
|
||||||
|
)
|
||||||
|
self.base_select.on_changed = self._populate_models
|
||||||
self.add_widget_intelligent(
|
self.add_widget_intelligent(
|
||||||
npyscreen.FixedText,
|
npyscreen.FixedText,
|
||||||
value="MODEL 1",
|
value="MODEL 1",
|
||||||
color="GOOD",
|
color="GOOD",
|
||||||
editable=False,
|
editable=False,
|
||||||
rely=4 if horizontal_layout else None,
|
rely=6 if horizontal_layout else None,
|
||||||
)
|
)
|
||||||
self.model1 = self.add_widget_intelligent(
|
self.model1 = self.add_widget_intelligent(
|
||||||
npyscreen.SelectOne,
|
npyscreen.SelectOne,
|
||||||
@ -145,7 +158,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
|||||||
max_height=len(self.model_names),
|
max_height=len(self.model_names),
|
||||||
max_width=max_width,
|
max_width=max_width,
|
||||||
scroll_exit=True,
|
scroll_exit=True,
|
||||||
rely=5,
|
rely=7,
|
||||||
)
|
)
|
||||||
self.add_widget_intelligent(
|
self.add_widget_intelligent(
|
||||||
npyscreen.FixedText,
|
npyscreen.FixedText,
|
||||||
@ -153,7 +166,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
|||||||
color="GOOD",
|
color="GOOD",
|
||||||
editable=False,
|
editable=False,
|
||||||
relx=max_width + 3 if horizontal_layout else None,
|
relx=max_width + 3 if horizontal_layout else None,
|
||||||
rely=4 if horizontal_layout else None,
|
rely=6 if horizontal_layout else None,
|
||||||
)
|
)
|
||||||
self.model2 = self.add_widget_intelligent(
|
self.model2 = self.add_widget_intelligent(
|
||||||
npyscreen.SelectOne,
|
npyscreen.SelectOne,
|
||||||
@ -163,7 +176,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
|||||||
max_height=len(self.model_names),
|
max_height=len(self.model_names),
|
||||||
max_width=max_width,
|
max_width=max_width,
|
||||||
relx=max_width + 3 if horizontal_layout else None,
|
relx=max_width + 3 if horizontal_layout else None,
|
||||||
rely=5 if horizontal_layout else None,
|
rely=7 if horizontal_layout else None,
|
||||||
scroll_exit=True,
|
scroll_exit=True,
|
||||||
)
|
)
|
||||||
self.add_widget_intelligent(
|
self.add_widget_intelligent(
|
||||||
@ -172,7 +185,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
|||||||
color="GOOD",
|
color="GOOD",
|
||||||
editable=False,
|
editable=False,
|
||||||
relx=max_width * 2 + 3 if horizontal_layout else None,
|
relx=max_width * 2 + 3 if horizontal_layout else None,
|
||||||
rely=4 if horizontal_layout else None,
|
rely=6 if horizontal_layout else None,
|
||||||
)
|
)
|
||||||
models_plus_none = self.model_names.copy()
|
models_plus_none = self.model_names.copy()
|
||||||
models_plus_none.insert(0, "None")
|
models_plus_none.insert(0, "None")
|
||||||
@ -185,24 +198,26 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
|||||||
max_width=max_width,
|
max_width=max_width,
|
||||||
scroll_exit=True,
|
scroll_exit=True,
|
||||||
relx=max_width * 2 + 3 if horizontal_layout else None,
|
relx=max_width * 2 + 3 if horizontal_layout else None,
|
||||||
rely=5 if horizontal_layout else None,
|
rely=7 if horizontal_layout else None,
|
||||||
)
|
)
|
||||||
for m in [self.model1, self.model2, self.model3]:
|
for m in [self.model1, self.model2, self.model3]:
|
||||||
m.when_value_edited = self.models_changed
|
m.when_value_edited = self.models_changed
|
||||||
self.merged_model_name = self.add_widget_intelligent(
|
self.merged_model_name = self.add_widget_intelligent(
|
||||||
npyscreen.TitleText,
|
TextBox,
|
||||||
name="Name for merged model:",
|
name="Name for merged model:",
|
||||||
labelColor="CONTROL",
|
labelColor="CONTROL",
|
||||||
|
max_height=3,
|
||||||
value="",
|
value="",
|
||||||
scroll_exit=True,
|
scroll_exit=True,
|
||||||
)
|
)
|
||||||
self.force = self.add_widget_intelligent(
|
self.force = self.add_widget_intelligent(
|
||||||
npyscreen.Checkbox,
|
npyscreen.Checkbox,
|
||||||
name="Force merge of incompatible models",
|
name="Force merge of models created by different diffusers library versions",
|
||||||
labelColor="CONTROL",
|
labelColor="CONTROL",
|
||||||
value=False,
|
value=True,
|
||||||
scroll_exit=True,
|
scroll_exit=True,
|
||||||
)
|
)
|
||||||
|
self.nextrely += 1
|
||||||
self.merge_method = self.add_widget_intelligent(
|
self.merge_method = self.add_widget_intelligent(
|
||||||
npyscreen.TitleSelectOne,
|
npyscreen.TitleSelectOne,
|
||||||
name="Merge Method:",
|
name="Merge Method:",
|
||||||
@ -264,7 +279,8 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
|||||||
interp = self.interpolations[self.merge_method.value[0]]
|
interp = self.interpolations[self.merge_method.value[0]]
|
||||||
|
|
||||||
args = dict(
|
args = dict(
|
||||||
models=models,
|
model_names=models,
|
||||||
|
base_model=tuple(BaseModelType)[self.base_select.value[0]],
|
||||||
alpha=self.alpha.value,
|
alpha=self.alpha.value,
|
||||||
interp=interp,
|
interp=interp,
|
||||||
force=self.force.value,
|
force=self.force.value,
|
||||||
@ -302,21 +318,30 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
|||||||
else:
|
else:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def get_model_names(self) -> List[str]:
|
def get_model_names(self, base_model: BaseModelType=None) -> List[str]:
|
||||||
model_names = [
|
model_names = [
|
||||||
name
|
info["name"]
|
||||||
for name in self.model_manager.model_names()
|
for info in self.model_manager.list_models(model_type=ModelType.Main, base_model=base_model)
|
||||||
if self.model_manager.model_info(name).get("format") == "diffusers"
|
if info["model_format"] == "diffusers"
|
||||||
]
|
]
|
||||||
return sorted(model_names)
|
return sorted(model_names)
|
||||||
|
|
||||||
|
def _populate_models(self,value=None):
|
||||||
|
base_model = tuple(BaseModelType)[value[0]]
|
||||||
|
self.model_names = self.get_model_names(base_model)
|
||||||
|
|
||||||
|
models_plus_none = self.model_names.copy()
|
||||||
|
models_plus_none.insert(0, "None")
|
||||||
|
self.model1.values = self.model_names
|
||||||
|
self.model2.values = self.model_names
|
||||||
|
self.model3.values = models_plus_none
|
||||||
|
|
||||||
|
self.display()
|
||||||
|
|
||||||
class Mergeapp(npyscreen.NPSAppManaged):
|
class Mergeapp(npyscreen.NPSAppManaged):
|
||||||
def __init__(self):
|
def __init__(self, model_manager:ModelManager):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
conf = OmegaConf.load(config.model_conf_path)
|
self.model_manager = model_manager
|
||||||
self.model_manager = ModelManager(
|
|
||||||
conf, "cpu", "float16"
|
|
||||||
) # precision doesn't really matter here
|
|
||||||
|
|
||||||
def onStart(self):
|
def onStart(self):
|
||||||
npyscreen.setTheme(npyscreen.Themes.ElegantTheme)
|
npyscreen.setTheme(npyscreen.Themes.ElegantTheme)
|
||||||
@ -324,38 +349,41 @@ class Mergeapp(npyscreen.NPSAppManaged):
|
|||||||
|
|
||||||
|
|
||||||
def run_gui(args: Namespace):
|
def run_gui(args: Namespace):
|
||||||
mergeapp = Mergeapp()
|
model_manager = ModelManager(config.model_conf_path)
|
||||||
|
mergeapp = Mergeapp(model_manager)
|
||||||
mergeapp.run()
|
mergeapp.run()
|
||||||
|
|
||||||
args = mergeapp.merge_arguments
|
args = mergeapp.merge_arguments
|
||||||
merge_diffusion_models_and_save(**args)
|
merger = ModelMerger(model_manager)
|
||||||
|
merger.merge_diffusion_models_and_save(**args)
|
||||||
logger.info(f'Models merged into new model: "{args["merged_model_name"]}".')
|
logger.info(f'Models merged into new model: "{args["merged_model_name"]}".')
|
||||||
|
|
||||||
|
|
||||||
def run_cli(args: Namespace):
|
def run_cli(args: Namespace):
|
||||||
assert args.alpha >= 0 and args.alpha <= 1.0, "alpha must be between 0 and 1"
|
assert args.alpha >= 0 and args.alpha <= 1.0, "alpha must be between 0 and 1"
|
||||||
assert (
|
assert (
|
||||||
args.models and len(args.models) >= 1 and len(args.models) <= 3
|
args.model_names and len(args.model_names) >= 1 and len(args.model_names) <= 3
|
||||||
), "Please provide the --models argument to list 2 to 3 models to merge. Use --help for full usage."
|
), "Please provide the --models argument to list 2 to 3 models to merge. Use --help for full usage."
|
||||||
|
|
||||||
if not args.merged_model_name:
|
if not args.merged_model_name:
|
||||||
args.merged_model_name = "+".join(args.models)
|
args.merged_model_name = "+".join(args.model_names)
|
||||||
logger.info(
|
logger.info(
|
||||||
f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"'
|
f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"'
|
||||||
)
|
)
|
||||||
|
|
||||||
model_manager = ModelManager(OmegaConf.load(config.model_conf_path))
|
model_manager = ModelManager(config.model_conf_path)
|
||||||
assert (
|
assert (
|
||||||
args.clobber or args.merged_model_name not in model_manager.model_names()
|
not model_manager.model_exists(args.merged_model_name, args.base_model, ModelType.Main) or args.clobber
|
||||||
), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
|
), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
|
||||||
|
|
||||||
merge_diffusion_models_and_commit(**vars(args))
|
merger = ModelMerger(model_manager)
|
||||||
logger.info(f'Models merged into new model: "{args.merged_model_name}".')
|
merger.merge_diffusion_models_and_save(**vars(args))
|
||||||
|
logger.info(f'Models merged into new model: "{args.merged_model_name}".')
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = _parse_args()
|
args = _parse_args()
|
||||||
config.parse_args(['--root',args.root_dir])
|
config.parse_args(['--root',str(args.root_dir)])
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if args.front_end:
|
if args.front_end:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user