mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Apply black
This commit is contained in:
@ -2,4 +2,3 @@
|
||||
Initialization file for invokeai.frontend.merge
|
||||
"""
|
||||
from .merge_diffusers import main as invokeai_merge_diffusers
|
||||
|
||||
|
@ -20,13 +20,17 @@ from omegaconf import OmegaConf
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_management import (
|
||||
ModelMerger, MergeInterpolationMethod,
|
||||
ModelManager, ModelType, BaseModelType,
|
||||
ModelMerger,
|
||||
MergeInterpolationMethod,
|
||||
ModelManager,
|
||||
ModelType,
|
||||
BaseModelType,
|
||||
)
|
||||
from invokeai.frontend.install.widgets import FloatTitleSlider, TextBox, SingleSelectColumns
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
|
||||
|
||||
def _parse_args() -> Namespace:
|
||||
parser = argparse.ArgumentParser(description="InvokeAI model merging")
|
||||
parser.add_argument(
|
||||
@ -134,14 +138,14 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
self.base_select = self.add_widget_intelligent(
|
||||
SingleSelectColumns,
|
||||
values=[
|
||||
'Models Built on SD-1.x',
|
||||
'Models Built on SD-2.x',
|
||||
"Models Built on SD-1.x",
|
||||
"Models Built on SD-2.x",
|
||||
],
|
||||
value=[self.current_base],
|
||||
columns = 4,
|
||||
max_height = 2,
|
||||
columns=4,
|
||||
max_height=2,
|
||||
relx=8,
|
||||
scroll_exit = True,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.base_select.on_changed = self._populate_models
|
||||
self.add_widget_intelligent(
|
||||
@ -300,15 +304,11 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
def validate_field_values(self) -> bool:
|
||||
bad_fields = []
|
||||
model_names = self.model_names
|
||||
selected_models = set(
|
||||
(model_names[self.model1.value[0]], model_names[self.model2.value[0]])
|
||||
)
|
||||
selected_models = set((model_names[self.model1.value[0]], model_names[self.model2.value[0]]))
|
||||
if self.model3.value[0] > 0:
|
||||
selected_models.add(model_names[self.model3.value[0] - 1])
|
||||
if len(selected_models) < 2:
|
||||
bad_fields.append(
|
||||
f"Please select two or three DIFFERENT models to compare. You selected {selected_models}"
|
||||
)
|
||||
bad_fields.append(f"Please select two or three DIFFERENT models to compare. You selected {selected_models}")
|
||||
if len(bad_fields) > 0:
|
||||
message = "The following problems were detected and must be corrected:"
|
||||
for problem in bad_fields:
|
||||
@ -318,7 +318,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
else:
|
||||
return True
|
||||
|
||||
def get_model_names(self, base_model: BaseModelType=None) -> List[str]:
|
||||
def get_model_names(self, base_model: BaseModelType = None) -> List[str]:
|
||||
model_names = [
|
||||
info["name"]
|
||||
for info in self.model_manager.list_models(model_type=ModelType.Main, base_model=base_model)
|
||||
@ -326,20 +326,21 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
]
|
||||
return sorted(model_names)
|
||||
|
||||
def _populate_models(self,value=None):
|
||||
def _populate_models(self, value=None):
|
||||
base_model = tuple(BaseModelType)[value[0]]
|
||||
self.model_names = self.get_model_names(base_model)
|
||||
|
||||
|
||||
models_plus_none = self.model_names.copy()
|
||||
models_plus_none.insert(0, "None")
|
||||
self.model1.values = self.model_names
|
||||
self.model2.values = self.model_names
|
||||
self.model3.values = models_plus_none
|
||||
|
||||
|
||||
self.display()
|
||||
|
||||
|
||||
class Mergeapp(npyscreen.NPSAppManaged):
|
||||
def __init__(self, model_manager:ModelManager):
|
||||
def __init__(self, model_manager: ModelManager):
|
||||
super().__init__()
|
||||
self.model_manager = model_manager
|
||||
|
||||
@ -367,9 +368,7 @@ def run_cli(args: Namespace):
|
||||
|
||||
if not args.merged_model_name:
|
||||
args.merged_model_name = "+".join(args.model_names)
|
||||
logger.info(
|
||||
f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"'
|
||||
)
|
||||
logger.info(f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"')
|
||||
|
||||
model_manager = ModelManager(config.model_conf_path)
|
||||
assert (
|
||||
@ -383,7 +382,7 @@ def run_cli(args: Namespace):
|
||||
|
||||
def main():
|
||||
args = _parse_args()
|
||||
config.parse_args(['--root',str(args.root_dir)])
|
||||
config.parse_args(["--root", str(args.root_dir)])
|
||||
|
||||
try:
|
||||
if args.front_end:
|
||||
@ -392,13 +391,9 @@ def main():
|
||||
run_cli(args)
|
||||
except widget.NotEnoughSpaceForWidget as e:
|
||||
if str(e).startswith("Height of 1 allocated"):
|
||||
logger.error(
|
||||
"You need to have at least two diffusers models defined in models.yaml in order to merge"
|
||||
)
|
||||
logger.error("You need to have at least two diffusers models defined in models.yaml in order to merge")
|
||||
else:
|
||||
logger.error(
|
||||
"Not enough room for the user interface. Try making this window larger."
|
||||
)
|
||||
logger.error("Not enough room for the user interface. Try making this window larger.")
|
||||
sys.exit(-1)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
|
Reference in New Issue
Block a user