mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
make textual inversion training work with new model manager
This commit is contained in:
@ -11,6 +11,7 @@ import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
@ -41,8 +42,8 @@ from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
# invokeai stuff
|
||||
from invokeai.app.services.config import InvokeAIAppConfig, PagingArgumentParser
|
||||
from invokeai.app.services.model_manager_service import ModelManagerService
|
||||
from invokeai.backend.model_management.models import SubModelType
|
||||
from invokeai.app.services.model_manager_service import BaseModelType, ModelManagerService, ModelType
|
||||
from invokeai.backend.model_manager import SubModelType
|
||||
|
||||
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
|
||||
PIL_INTERPOLATION = {
|
||||
@ -66,7 +67,6 @@ else:
|
||||
# Will error if the minimum version of diffusers is not installed. Remove at your own risk.
|
||||
check_min_version("0.10.0.dev0")
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@ -114,7 +114,6 @@ def parse_args():
|
||||
general_group.add_argument(
|
||||
"--output_dir",
|
||||
type=Path,
|
||||
default=f"{config.root}/text-inversion-model",
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
model_group.add_argument(
|
||||
@ -550,8 +549,11 @@ def do_textual_inversion_training(
|
||||
local_rank = env_local_rank
|
||||
|
||||
# setting up things the way invokeai expects them
|
||||
output_dir = output_dir or config.root_path / "text-inversion-output"
|
||||
|
||||
print(f"output_dir={output_dir}")
|
||||
if not os.path.isabs(output_dir):
|
||||
output_dir = os.path.join(config.root, output_dir)
|
||||
output_dir = Path(config.root, output_dir)
|
||||
|
||||
logging_dir = output_dir / logging_dir
|
||||
|
||||
@ -564,14 +566,15 @@ def do_textual_inversion_training(
|
||||
project_config=accelerator_config,
|
||||
)
|
||||
|
||||
model_manager = ModelManagerService(config, logger)
|
||||
model_manager = ModelManagerService(config)
|
||||
|
||||
# The InvokeAI logger already does this...
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
)
|
||||
# logging.basicConfig(
|
||||
# format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
# datefmt="%m/%d/%Y %H:%M:%S",
|
||||
# level=logging.INFO,
|
||||
# )
|
||||
logger.info(accelerator.state, main_process_only=False)
|
||||
if accelerator.is_local_main_process:
|
||||
datasets.utils.logging.set_verbosity_warning()
|
||||
@ -603,17 +606,30 @@ def do_textual_inversion_training(
|
||||
elif output_dir is not None:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
known_models = model_manager.model_names()
|
||||
model_name = model.split("/")[-1]
|
||||
model_meta = next((mm for mm in known_models if mm[0].endswith(model_name)), None)
|
||||
assert model_meta is not None, f"Unknown model: {model}"
|
||||
model_info = model_manager.model_info(*model_meta)
|
||||
assert model_info["model_format"] == "diffusers", "This script only works with models of type 'diffusers'"
|
||||
tokenizer_info = model_manager.get_model(*model_meta, submodel=SubModelType.Tokenizer)
|
||||
noise_scheduler_info = model_manager.get_model(*model_meta, submodel=SubModelType.Scheduler)
|
||||
text_encoder_info = model_manager.get_model(*model_meta, submodel=SubModelType.TextEncoder)
|
||||
vae_info = model_manager.get_model(*model_meta, submodel=SubModelType.Vae)
|
||||
unet_info = model_manager.get_model(*model_meta, submodel=SubModelType.UNet)
|
||||
if len(model) == 32 and re.match(r"^[0-9a-f]+$", model): # looks like a key, not a model name
|
||||
model_key = model
|
||||
else:
|
||||
parts = model.split("/")
|
||||
if len(parts) == 3:
|
||||
base_model, model_type, model_name = parts
|
||||
else:
|
||||
model_name = parts[-1]
|
||||
base_model = BaseModelType("sd-1")
|
||||
model_type = ModelType.Main
|
||||
models = model_manager.list_models(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
)
|
||||
assert len(models) > 0, f"Unknown model: {model}"
|
||||
assert len(models) < 2, "More than one model named {model_name}. Please pass key instead."
|
||||
model_key = models[0].key
|
||||
|
||||
tokenizer_info = model_manager.get_model(model_key, submodel_type=SubModelType.Tokenizer)
|
||||
noise_scheduler_info = model_manager.get_model(model_key, submodel_type=SubModelType.Scheduler)
|
||||
text_encoder_info = model_manager.get_model(model_key, submodel_type=SubModelType.TextEncoder)
|
||||
vae_info = model_manager.get_model(model_key, submodel_type=SubModelType.Vae)
|
||||
unet_info = model_manager.get_model(model_key, submodel_type=SubModelType.UNet)
|
||||
|
||||
pipeline_args = dict(local_files_only=True)
|
||||
if tokenizer_name:
|
||||
|
@ -22,6 +22,7 @@ from omegaconf import OmegaConf
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_manager import ModelConfigStore, ModelType, get_config_store
|
||||
|
||||
from ...backend.training import do_textual_inversion_training, parse_args
|
||||
|
||||
@ -275,10 +276,13 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
|
||||
return True
|
||||
|
||||
def get_model_names(self) -> Tuple[List[str], int]:
|
||||
conf = OmegaConf.load(config.root_dir / "configs/models.yaml")
|
||||
model_names = [idx for idx in sorted(list(conf.keys())) if conf[idx].get("format", None) == "diffusers"]
|
||||
defaults = [idx for idx in range(len(model_names)) if "default" in conf[model_names[idx]]]
|
||||
default = defaults[0] if len(defaults) > 0 else 0
|
||||
global config
|
||||
store: ModelConfigStore = get_config_store(config.model_conf_path)
|
||||
main_models = store.search_by_name(model_type=ModelType.Main)
|
||||
model_names = [
|
||||
f"{x.base_model.value}/{x.model_type.value}/{x.name}" for x in main_models if x.model_format == "diffusers"
|
||||
]
|
||||
default = 0
|
||||
return (model_names, default)
|
||||
|
||||
def marshall_arguments(self) -> dict:
|
||||
@ -384,6 +388,7 @@ def previous_args() -> dict:
|
||||
|
||||
|
||||
def do_front_end(args: Namespace):
|
||||
global config
|
||||
saved_args = previous_args()
|
||||
myapplication = MyApplication(saved_args=saved_args)
|
||||
myapplication.run()
|
||||
@ -399,7 +404,7 @@ def do_front_end(args: Namespace):
|
||||
save_args(args)
|
||||
|
||||
try:
|
||||
do_textual_inversion_training(InvokeAIAppConfig.get_config(), **args)
|
||||
do_textual_inversion_training(config, **args)
|
||||
copy_to_embeddings_folder(args)
|
||||
except Exception as e:
|
||||
logger.error("An exception occurred during training. The exception was:")
|
||||
@ -413,6 +418,7 @@ def main():
|
||||
|
||||
args = parse_args()
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
config.parse_args([])
|
||||
|
||||
# change root if needed
|
||||
if args.root_dir:
|
||||
|
Reference in New Issue
Block a user