make textual inversion training work with new model manager

This commit is contained in:
Lincoln Stein
2023-10-02 22:23:49 -04:00
parent 63f6c12aa3
commit 48c3d926b0
2 changed files with 49 additions and 27 deletions

View File

@ -11,6 +11,7 @@ import logging
import math
import os
import random
import re
from pathlib import Path
from typing import Optional
@ -41,8 +42,8 @@ from transformers import CLIPTextModel, CLIPTokenizer
# invokeai stuff
from invokeai.app.services.config import InvokeAIAppConfig, PagingArgumentParser
from invokeai.app.services.model_manager_service import ModelManagerService
from invokeai.backend.model_management.models import SubModelType
from invokeai.app.services.model_manager_service import BaseModelType, ModelManagerService, ModelType
from invokeai.backend.model_manager import SubModelType
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
PIL_INTERPOLATION = {
@ -66,7 +67,6 @@ else:
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.10.0.dev0")
logger = get_logger(__name__)
@ -114,7 +114,6 @@ def parse_args():
general_group.add_argument(
"--output_dir",
type=Path,
default=f"{config.root}/text-inversion-model",
help="The output directory where the model predictions and checkpoints will be written.",
)
model_group.add_argument(
@ -550,8 +549,11 @@ def do_textual_inversion_training(
local_rank = env_local_rank
# setting up things the way invokeai expects them
output_dir = output_dir or config.root_path / "text-inversion-output"
print(f"output_dir={output_dir}")
if not os.path.isabs(output_dir):
output_dir = os.path.join(config.root, output_dir)
output_dir = Path(config.root, output_dir)
logging_dir = output_dir / logging_dir
@ -564,14 +566,15 @@ def do_textual_inversion_training(
project_config=accelerator_config,
)
model_manager = ModelManagerService(config, logger)
model_manager = ModelManagerService(config)
# The InvokeAI logger already does this...
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
# logging.basicConfig(
# format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
# datefmt="%m/%d/%Y %H:%M:%S",
# level=logging.INFO,
# )
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
datasets.utils.logging.set_verbosity_warning()
@ -603,17 +606,30 @@ def do_textual_inversion_training(
elif output_dir is not None:
os.makedirs(output_dir, exist_ok=True)
known_models = model_manager.model_names()
model_name = model.split("/")[-1]
model_meta = next((mm for mm in known_models if mm[0].endswith(model_name)), None)
assert model_meta is not None, f"Unknown model: {model}"
model_info = model_manager.model_info(*model_meta)
assert model_info["model_format"] == "diffusers", "This script only works with models of type 'diffusers'"
tokenizer_info = model_manager.get_model(*model_meta, submodel=SubModelType.Tokenizer)
noise_scheduler_info = model_manager.get_model(*model_meta, submodel=SubModelType.Scheduler)
text_encoder_info = model_manager.get_model(*model_meta, submodel=SubModelType.TextEncoder)
vae_info = model_manager.get_model(*model_meta, submodel=SubModelType.Vae)
unet_info = model_manager.get_model(*model_meta, submodel=SubModelType.UNet)
if len(model) == 32 and re.match(r"^[0-9a-f]+$", model): # looks like a key, not a model name
model_key = model
else:
parts = model.split("/")
if len(parts) == 3:
base_model, model_type, model_name = parts
else:
model_name = parts[-1]
base_model = BaseModelType("sd-1")
model_type = ModelType.Main
models = model_manager.list_models(
model_name=model_name,
base_model=base_model,
model_type=model_type,
)
assert len(models) > 0, f"Unknown model: {model}"
    assert len(models) < 2, f"More than one model named {model_name}. Please pass key instead."
model_key = models[0].key
tokenizer_info = model_manager.get_model(model_key, submodel_type=SubModelType.Tokenizer)
noise_scheduler_info = model_manager.get_model(model_key, submodel_type=SubModelType.Scheduler)
text_encoder_info = model_manager.get_model(model_key, submodel_type=SubModelType.TextEncoder)
vae_info = model_manager.get_model(model_key, submodel_type=SubModelType.Vae)
unet_info = model_manager.get_model(model_key, submodel_type=SubModelType.UNet)
pipeline_args = dict(local_files_only=True)
if tokenizer_name:

View File

@ -22,6 +22,7 @@ from omegaconf import OmegaConf
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager import ModelConfigStore, ModelType, get_config_store
from ...backend.training import do_textual_inversion_training, parse_args
@ -275,10 +276,13 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
return True
def get_model_names(self) -> Tuple[List[str], int]:
conf = OmegaConf.load(config.root_dir / "configs/models.yaml")
model_names = [idx for idx in sorted(list(conf.keys())) if conf[idx].get("format", None) == "diffusers"]
defaults = [idx for idx in range(len(model_names)) if "default" in conf[model_names[idx]]]
default = defaults[0] if len(defaults) > 0 else 0
global config
store: ModelConfigStore = get_config_store(config.model_conf_path)
main_models = store.search_by_name(model_type=ModelType.Main)
model_names = [
f"{x.base_model.value}/{x.model_type.value}/{x.name}" for x in main_models if x.model_format == "diffusers"
]
default = 0
return (model_names, default)
def marshall_arguments(self) -> dict:
@ -384,6 +388,7 @@ def previous_args() -> dict:
def do_front_end(args: Namespace):
global config
saved_args = previous_args()
myapplication = MyApplication(saved_args=saved_args)
myapplication.run()
@ -399,7 +404,7 @@ def do_front_end(args: Namespace):
save_args(args)
try:
do_textual_inversion_training(InvokeAIAppConfig.get_config(), **args)
do_textual_inversion_training(config, **args)
copy_to_embeddings_folder(args)
except Exception as e:
logger.error("An exception occurred during training. The exception was:")
@ -413,6 +418,7 @@ def main():
args = parse_args()
config = InvokeAIAppConfig.get_config()
config.parse_args([])
# change root if needed
if args.root_dir: