diff --git a/invokeai/backend/training/textual_inversion_training.py b/invokeai/backend/training/textual_inversion_training.py index 153bd0fcc4..5cd77119d8 100644 --- a/invokeai/backend/training/textual_inversion_training.py +++ b/invokeai/backend/training/textual_inversion_training.py @@ -11,6 +11,7 @@ import logging import math import os import random +import re from pathlib import Path from typing import Optional @@ -41,8 +42,8 @@ from transformers import CLIPTextModel, CLIPTokenizer # invokeai stuff from invokeai.app.services.config import InvokeAIAppConfig, PagingArgumentParser -from invokeai.app.services.model_manager_service import ModelManagerService -from invokeai.backend.model_management.models import SubModelType +from invokeai.app.services.model_manager_service import BaseModelType, ModelManagerService, ModelType +from invokeai.backend.model_manager import SubModelType if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): PIL_INTERPOLATION = { @@ -66,7 +67,6 @@ else: # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
check_min_version("0.10.0.dev0") - logger = get_logger(__name__) @@ -114,7 +114,6 @@ def parse_args(): general_group.add_argument( "--output_dir", type=Path, - default=f"{config.root}/text-inversion-model", help="The output directory where the model predictions and checkpoints will be written.", ) model_group.add_argument( @@ -550,8 +549,11 @@ def do_textual_inversion_training( local_rank = env_local_rank # setting up things the way invokeai expects them + output_dir = output_dir or config.root_path / "text-inversion-output" + + print(f"output_dir={output_dir}") if not os.path.isabs(output_dir): - output_dir = os.path.join(config.root, output_dir) + output_dir = Path(config.root, output_dir) logging_dir = output_dir / logging_dir @@ -564,14 +566,15 @@ def do_textual_inversion_training( project_config=accelerator_config, ) - model_manager = ModelManagerService(config, logger) + model_manager = ModelManagerService(config) + # The InvokeAI logger already does this... # Make one log on every process with the configuration for debugging. 
- logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%m/%d/%Y %H:%M:%S", - level=logging.INFO, - ) + # logging.basicConfig( + # format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + # datefmt="%m/%d/%Y %H:%M:%S", + # level=logging.INFO, + # ) logger.info(accelerator.state, main_process_only=False) if accelerator.is_local_main_process: datasets.utils.logging.set_verbosity_warning() @@ -603,17 +606,30 @@ def do_textual_inversion_training( elif output_dir is not None: os.makedirs(output_dir, exist_ok=True) - known_models = model_manager.model_names() - model_name = model.split("/")[-1] - model_meta = next((mm for mm in known_models if mm[0].endswith(model_name)), None) - assert model_meta is not None, f"Unknown model: {model}" - model_info = model_manager.model_info(*model_meta) - assert model_info["model_format"] == "diffusers", "This script only works with models of type 'diffusers'" - tokenizer_info = model_manager.get_model(*model_meta, submodel=SubModelType.Tokenizer) - noise_scheduler_info = model_manager.get_model(*model_meta, submodel=SubModelType.Scheduler) - text_encoder_info = model_manager.get_model(*model_meta, submodel=SubModelType.TextEncoder) - vae_info = model_manager.get_model(*model_meta, submodel=SubModelType.Vae) - unet_info = model_manager.get_model(*model_meta, submodel=SubModelType.UNet) + if len(model) == 32 and re.match(r"^[0-9a-f]+$", model): # looks like a key, not a model name + model_key = model + else: + parts = model.split("/") + if len(parts) == 3: + base_model, model_type, model_name = parts + else: + model_name = parts[-1] + base_model = BaseModelType("sd-1") + model_type = ModelType.Main + models = model_manager.list_models( + model_name=model_name, + base_model=base_model, + model_type=model_type, + ) + assert len(models) > 0, f"Unknown model: {model}" + assert len(models) < 2, f"More than one model named {model_name}. Please pass key instead."
+ model_key = models[0].key + + tokenizer_info = model_manager.get_model(model_key, submodel_type=SubModelType.Tokenizer) + noise_scheduler_info = model_manager.get_model(model_key, submodel_type=SubModelType.Scheduler) + text_encoder_info = model_manager.get_model(model_key, submodel_type=SubModelType.TextEncoder) + vae_info = model_manager.get_model(model_key, submodel_type=SubModelType.Vae) + unet_info = model_manager.get_model(model_key, submodel_type=SubModelType.UNet) pipeline_args = dict(local_files_only=True) if tokenizer_name: diff --git a/invokeai/frontend/training/textual_inversion.py b/invokeai/frontend/training/textual_inversion.py index f3911f7e0e..7236511ddb 100755 --- a/invokeai/frontend/training/textual_inversion.py +++ b/invokeai/frontend/training/textual_inversion.py @@ -22,6 +22,7 @@ from omegaconf import OmegaConf import invokeai.backend.util.logging as logger from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.backend.model_manager import ModelConfigStore, ModelType, get_config_store from ...backend.training import do_textual_inversion_training, parse_args @@ -275,10 +276,13 @@ class textualInversionForm(npyscreen.FormMultiPageAction): return True def get_model_names(self) -> Tuple[List[str], int]: - conf = OmegaConf.load(config.root_dir / "configs/models.yaml") - model_names = [idx for idx in sorted(list(conf.keys())) if conf[idx].get("format", None) == "diffusers"] - defaults = [idx for idx in range(len(model_names)) if "default" in conf[model_names[idx]]] - default = defaults[0] if len(defaults) > 0 else 0 + global config + store: ModelConfigStore = get_config_store(config.model_conf_path) + main_models = store.search_by_name(model_type=ModelType.Main) + model_names = [ + f"{x.base_model.value}/{x.model_type.value}/{x.name}" for x in main_models if x.model_format == "diffusers" + ] + default = 0 return (model_names, default) def marshall_arguments(self) -> dict: @@ -384,6 +388,7 @@ def previous_args() -> dict: def 
do_front_end(args: Namespace): + global config saved_args = previous_args() myapplication = MyApplication(saved_args=saved_args) myapplication.run() @@ -399,7 +404,7 @@ def do_front_end(args: Namespace): save_args(args) try: - do_textual_inversion_training(InvokeAIAppConfig.get_config(), **args) + do_textual_inversion_training(config, **args) copy_to_embeddings_folder(args) except Exception as e: logger.error("An exception occurred during training. The exception was:") @@ -413,6 +418,7 @@ def main(): args = parse_args() config = InvokeAIAppConfig.get_config() + config.parse_args([]) # change root if needed if args.root_dir: