mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Port the command-line tools to use model_manager2 (#5546)
* Port the command-line tools to use model_manager2 1. Reimplement the following: - invokeai-model-install - invokeai-merge - invokeai-ti To avoid breaking the original model manager, the updated tools have been renamed invokeai-model-install2 and invokeai-merge2. The textual inversion training script should continue to work with existing installations. The "starter" models now live in `invokeai/configs/INITIAL_MODELS2.yaml`. When the full model manager 2 is in place and working, I'll rename these files and commands. 2. Add the `merge` route to the web API. This will merge two or three models, resulting in a new one. - Note that because the model installer selectively installs the `fp16` variant of models (rather than both 16- and 32-bit versions as previously), the diffusers merge script will choke on any huggingface diffusers models that were downloaded with the new installer. Previously-downloaded models should continue to merge correctly. I have a PR upstream https://github.com/huggingface/diffusers/pull/6670 to fix this. 3. (more important!) During implementation of the CLI tools, found and fixed a number of small runtime bugs in the model_manager2 implementation: - During model database migration, if a registered model's file was not found on disk, the migration would be aborted. Now the offending model is skipped with a log warning. - Caught and fixed a condition in which the installer would download the entire diffusers repo when the user provided a single `.safetensors` file URL. - Caught and fixed a condition in which the installer would raise an exception and stop the app when a request for an unknown model's metadata was passed to Civitai. Now an error is logged and the installer continues. - Replaced the LoWRA starter LoRA with FlatColor. The former has been removed from Civitai. * fix ruff issue --------- Co-authored-by: Lincoln Stein <lstein@gmail.com>
This commit is contained in:
@ -11,6 +11,7 @@ import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
@ -30,8 +31,6 @@ from diffusers.optimization import get_scheduler
|
||||
from diffusers.utils import check_min_version
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from huggingface_hub import HfFolder, Repository, whoami
|
||||
|
||||
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
|
||||
from packaging import version
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
@ -41,8 +40,8 @@ from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
# invokeai stuff
|
||||
from invokeai.app.services.config import InvokeAIAppConfig, PagingArgumentParser
|
||||
from invokeai.app.services.model_manager import ModelManagerService
|
||||
from invokeai.backend.model_management.models import SubModelType
|
||||
from invokeai.backend.install.install_helper import initialize_record_store
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelType
|
||||
|
||||
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
|
||||
PIL_INTERPOLATION = {
|
||||
@ -77,7 +76,7 @@ def save_progress(text_encoder, placeholder_token_id, accelerator, placeholder_t
|
||||
torch.save(learned_embeds_dict, save_path)
|
||||
|
||||
|
||||
def parse_args():
|
||||
def parse_args() -> Namespace:
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
parser = PagingArgumentParser(description="Textual inversion training")
|
||||
general_group = parser.add_argument_group("General")
|
||||
@ -444,7 +443,7 @@ class TextualInversionDataset(Dataset):
|
||||
self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
|
||||
self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
return self._length
|
||||
|
||||
def __getitem__(self, i):
|
||||
@ -509,11 +508,10 @@ def do_textual_inversion_training(
|
||||
initializer_token: str,
|
||||
save_steps: int = 500,
|
||||
only_save_embeds: bool = False,
|
||||
revision: str = None,
|
||||
tokenizer_name: str = None,
|
||||
tokenizer_name: Optional[str] = None,
|
||||
learnable_property: str = "object",
|
||||
repeats: int = 100,
|
||||
seed: int = None,
|
||||
seed: Optional[int] = None,
|
||||
resolution: int = 512,
|
||||
center_crop: bool = False,
|
||||
train_batch_size: int = 16,
|
||||
@ -530,18 +528,18 @@ def do_textual_inversion_training(
|
||||
adam_weight_decay: float = 1e-02,
|
||||
adam_epsilon: float = 1e-08,
|
||||
push_to_hub: bool = False,
|
||||
hub_token: str = None,
|
||||
hub_token: Optional[str] = None,
|
||||
logging_dir: Path = Path("logs"),
|
||||
mixed_precision: str = "fp16",
|
||||
allow_tf32: bool = False,
|
||||
report_to: str = "tensorboard",
|
||||
local_rank: int = -1,
|
||||
checkpointing_steps: int = 500,
|
||||
resume_from_checkpoint: Path = None,
|
||||
resume_from_checkpoint: Optional[Path] = None,
|
||||
enable_xformers_memory_efficient_attention: bool = False,
|
||||
hub_model_id: str = None,
|
||||
hub_model_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
) -> None:
|
||||
assert model, "Please specify a base model with --model"
|
||||
assert train_data_dir, "Please specify a directory containing the training images using --train_data_dir"
|
||||
assert placeholder_token, "Please specify a trigger term using --placeholder_token"
|
||||
@ -564,8 +562,6 @@ def do_textual_inversion_training(
|
||||
project_config=accelerator_config,
|
||||
)
|
||||
|
||||
model_manager = ModelManagerService(config, logger)
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
@ -603,44 +599,37 @@ def do_textual_inversion_training(
|
||||
elif output_dir is not None:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
known_models = model_manager.model_names()
|
||||
model_name = model.split("/")[-1]
|
||||
model_meta = next((mm for mm in known_models if mm[0].endswith(model_name)), None)
|
||||
assert model_meta is not None, f"Unknown model: {model}"
|
||||
model_info = model_manager.model_info(*model_meta)
|
||||
assert model_info["model_format"] == "diffusers", "This script only works with models of type 'diffusers'"
|
||||
tokenizer_info = model_manager.get_model(*model_meta, submodel=SubModelType.Tokenizer)
|
||||
noise_scheduler_info = model_manager.get_model(*model_meta, submodel=SubModelType.Scheduler)
|
||||
text_encoder_info = model_manager.get_model(*model_meta, submodel=SubModelType.TextEncoder)
|
||||
vae_info = model_manager.get_model(*model_meta, submodel=SubModelType.Vae)
|
||||
unet_info = model_manager.get_model(*model_meta, submodel=SubModelType.UNet)
|
||||
model_records = initialize_record_store(config)
|
||||
base, type, name = model.split("/") # note frontend still returns old-style keys
|
||||
try:
|
||||
model_config = model_records.search_by_attr(
|
||||
model_name=name, model_type=ModelType(type), base_model=BaseModelType(base)
|
||||
)[0]
|
||||
except IndexError:
|
||||
raise Exception(f"Unknown model {model}")
|
||||
model_path = config.models_path / model_config.path
|
||||
|
||||
pipeline_args = {"local_files_only": True}
|
||||
if tokenizer_name:
|
||||
tokenizer = CLIPTokenizer.from_pretrained(tokenizer_name, **pipeline_args)
|
||||
else:
|
||||
tokenizer = CLIPTokenizer.from_pretrained(tokenizer_info.location, subfolder="tokenizer", **pipeline_args)
|
||||
tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer", **pipeline_args)
|
||||
|
||||
# Load scheduler and models
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(
|
||||
noise_scheduler_info.location, subfolder="scheduler", **pipeline_args
|
||||
)
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(model_path, subfolder="scheduler", **pipeline_args)
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
text_encoder_info.location,
|
||||
model_path,
|
||||
subfolder="text_encoder",
|
||||
revision=revision,
|
||||
**pipeline_args,
|
||||
)
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
vae_info.location,
|
||||
model_path,
|
||||
subfolder="vae",
|
||||
revision=revision,
|
||||
**pipeline_args,
|
||||
)
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
unet_info.location,
|
||||
model_path,
|
||||
subfolder="unet",
|
||||
revision=revision,
|
||||
**pipeline_args,
|
||||
)
|
||||
|
||||
@ -728,7 +717,7 @@ def do_textual_inversion_training(
|
||||
max_train_steps = num_train_epochs * num_update_steps_per_epoch
|
||||
overrode_max_train_steps = True
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
scheduler = get_scheduler(
|
||||
lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
|
||||
@ -737,7 +726,7 @@ def do_textual_inversion_training(
|
||||
|
||||
# Prepare everything with our `accelerator`.
|
||||
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
text_encoder, optimizer, train_dataloader, lr_scheduler
|
||||
text_encoder, optimizer, train_dataloader, scheduler
|
||||
)
|
||||
|
||||
# For mixed precision training we cast the unet and vae weights to half-precision
|
||||
@ -863,7 +852,7 @@ def do_textual_inversion_training(
|
||||
accelerator.backward(loss)
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Let's make sure we don't update any embedding weights besides the newly added token
|
||||
@ -893,7 +882,7 @@ def do_textual_inversion_training(
|
||||
accelerator.save_state(save_path)
|
||||
logger.info(f"Saved state to {save_path}")
|
||||
|
||||
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
||||
logs = {"loss": loss.detach().item(), "lr": scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
@ -910,7 +899,7 @@ def do_textual_inversion_training(
|
||||
save_full_model = not only_save_embeds
|
||||
if save_full_model:
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
unet_info.location,
|
||||
model_path,
|
||||
text_encoder=accelerator.unwrap_model(text_encoder),
|
||||
vae=vae,
|
||||
unet=unet,
|
||||
|
Reference in New Issue
Block a user