mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
clean up merge_models
This commit is contained in:
parent
d3a469d136
commit
9dc3832b9b
@ -708,7 +708,7 @@ def optimize_model(model_name_or_path:str, gen, opt, completer):
|
||||
if not ckpt_path.is_absolute():
|
||||
ckpt_path = Path(Globals.root,ckpt_path)
|
||||
|
||||
diffuser_path = Path(Globals.root, 'models','optimized-ckpts',model_name)
|
||||
diffuser_path = Path(Globals.root, 'models',Globals.converted_ckpts_dir,model_name)
|
||||
if diffuser_path.exists():
|
||||
print(f'** {model_name_or_path} is already optimized. Will not overwrite. If this is an error, please remove the directory {diffuser_path} and try again.')
|
||||
return
|
||||
|
@ -33,6 +33,7 @@ Globals.models_file = 'models.yaml'
|
||||
Globals.models_dir = 'models'
|
||||
Globals.config_dir = 'configs'
|
||||
Globals.autoscan_dir = 'weights'
|
||||
Globals.converted_ckpts_dir = 'converted-ckpts'
|
||||
|
||||
# Try loading patchmatch
|
||||
Globals.try_patchmatch = True
|
||||
|
@ -1,21 +1,74 @@
|
||||
'''
|
||||
"""
|
||||
ldm.invoke.merge_diffusers exports a single function call merge_diffusion_models()
|
||||
used to merge 2-3 models together and create a new InvokeAI-registered diffusion model.
|
||||
'''
|
||||
|
||||
Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
|
||||
"""
|
||||
import argparse
|
||||
import os
|
||||
from typing import List
|
||||
import sys
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
from typing import List, Union
|
||||
|
||||
import npyscreen
|
||||
from diffusers import DiffusionPipeline
|
||||
from ldm.invoke.globals import global_config_file, global_models_dir, global_cache_dir
|
||||
from ldm.invoke.model_manager import ModelManager
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
def merge_diffusion_models(models:List['str'],
|
||||
merged_model_name:str,
|
||||
alpha:float=0.5,
|
||||
interp:str=None,
|
||||
force:bool=False,
|
||||
**kwargs):
|
||||
'''
|
||||
from ldm.invoke.globals import (
|
||||
Globals,
|
||||
global_cache_dir,
|
||||
global_config_file,
|
||||
global_models_dir,
|
||||
global_set_root,
|
||||
)
|
||||
from ldm.invoke.model_manager import ModelManager
|
||||
|
||||
DEST_MERGED_MODEL_DIR = "merged_models"
|
||||
|
||||
|
||||
def merge_diffusion_models(
|
||||
model_ids_or_paths: List[Union[str, Path]],
|
||||
alpha: float = 0.5,
|
||||
interp: str = None,
|
||||
force: bool = False,
|
||||
**kwargs,
|
||||
) -> DiffusionPipeline:
|
||||
"""
|
||||
model_ids_or_paths - up to three models, designated by their local paths or HuggingFace repo_ids
|
||||
alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
|
||||
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
|
||||
interp - The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None.
|
||||
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported.
|
||||
force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
|
||||
|
||||
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
||||
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
||||
"""
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
model_ids_or_paths[0],
|
||||
cache_dir=kwargs.get("cache_dir", global_cache_dir()),
|
||||
custom_pipeline="checkpoint_merger",
|
||||
)
|
||||
merged_pipe = pipe.merge(
|
||||
pretrained_model_name_or_path_list=model_ids_or_paths,
|
||||
alpha=alpha,
|
||||
interp=interp,
|
||||
force=force,
|
||||
**kwargs,
|
||||
)
|
||||
return merged_pipe
|
||||
|
||||
|
||||
def merge_diffusion_models_and_commit(
|
||||
models: List["str"],
|
||||
merged_model_name: str,
|
||||
alpha: float = 0.5,
|
||||
interp: str = None,
|
||||
force: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
models - up to three models, designated by their InvokeAI models.yaml model name
|
||||
merged_model_name = name for new model
|
||||
alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
|
||||
@ -26,37 +79,303 @@ def merge_diffusion_models(models:List['str'],
|
||||
|
||||
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
||||
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
||||
'''
|
||||
"""
|
||||
config_file = global_config_file()
|
||||
model_manager = ModelManager(OmegaConf.load(config_file))
|
||||
for mod in models:
|
||||
assert (mod in model_manager.model_names()), f'** Unknown model "{mod}"'
|
||||
assert (model_manager.model_info(mod).get('format',None) == 'diffusers'), f'** {mod} is not a diffusers model. It must be optimized before merging.'
|
||||
assert mod in model_manager.model_names(), f'** Unknown model "{mod}"'
|
||||
assert (
|
||||
model_manager.model_info(mod).get("format", None) == "diffusers"
|
||||
), f"** {mod} is not a diffusers model. It must be optimized before merging."
|
||||
model_ids_or_paths = [model_manager.model_name_or_path(x) for x in models]
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(model_ids_or_paths[0],
|
||||
cache_dir=kwargs.get('cache_dir',global_cache_dir()),
|
||||
custom_pipeline='checkpoint_merger')
|
||||
merged_pipe = pipe.merge(pretrained_model_name_or_path_list=model_ids_or_paths,
|
||||
alpha=alpha,
|
||||
interp=interp,
|
||||
force=force,
|
||||
**kwargs)
|
||||
dump_path = global_models_dir() / 'merged_diffusers'
|
||||
os.makedirs(dump_path,exist_ok=True)
|
||||
dump_path = dump_path / merged_model_name
|
||||
merged_pipe.save_pretrained (
|
||||
dump_path,
|
||||
safe_serialization=1
|
||||
merged_pipe = merge_diffusion_models(
|
||||
model_ids_or_paths, alpha, interp, force, **kwargs
|
||||
)
|
||||
model_manager.import_diffuser_model(
|
||||
dump_path,
|
||||
model_name = merged_model_name,
|
||||
description = f'Merge of models {", ".join(models)}'
|
||||
)
|
||||
print('REMINDER: When PR 2369 is merged, replace merge_diffusers.py line 56 with vae= argument to impormodel()')
|
||||
if vae := model_manager.config[models[0]].get('vae',None):
|
||||
print(f'>> Using configured VAE assigned to {models[0]}')
|
||||
model_manager.config[merged_model_name]['vae'] = vae
|
||||
dump_path = global_models_dir() / DEST_MERGED_MODEL_DIR
|
||||
|
||||
os.makedirs(dump_path, exist_ok=True)
|
||||
dump_path = dump_path / merged_model_name
|
||||
merged_pipe.save_pretrained(dump_path, safe_serialization=1)
|
||||
import_args = dict(
|
||||
model_name=merged_model_name, description=f'Merge of models {", ".join(models)}'
|
||||
)
|
||||
if vae := model_manager.config[models[0]].get("vae", None):
|
||||
print(f">> Using configured VAE assigned to {models[0]}")
|
||||
import_args.update(vae=vae)
|
||||
model_manager.import_diffuser_model(dump_path, **import_args)
|
||||
model_manager.commit(config_file)
|
||||
|
||||
|
||||
def _parse_args() -> Namespace:
|
||||
parser = argparse.ArgumentParser(description="InvokeAI model merging")
|
||||
parser.add_argument(
|
||||
"--root_dir",
|
||||
type=Path,
|
||||
default=Globals.root,
|
||||
help="Path to the invokeai runtime directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--front_end",
|
||||
"--gui",
|
||||
dest="front_end",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Activate the text-based graphical front end for collecting parameters. Aside from --root_dir, other parameters will be ignored.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--models",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="Two to three model names to be merged",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--merged_model_name",
|
||||
"--destination",
|
||||
dest="merged_model_name",
|
||||
type=str,
|
||||
help="Name of the output model. If not specified, will be the concatenation of the input model names.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--alpha",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="The interpolation parameter, ranging from 0 to 1. It affects the ratio in which the checkpoints are merged. Higher values give more weight to the 2d and 3d models",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--interpolation",
|
||||
dest="interp",
|
||||
type=str,
|
||||
choices=["weighted_sum", "sigmoid", "inv_sigmoid", "add_difference"],
|
||||
default="weighted_sum",
|
||||
help='Interpolation method to use. If three models are present, only "add_difference" will work.',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--force",
|
||||
action="store_true",
|
||||
help="Try to merge models even if they are incompatible with each other",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--clobber",
|
||||
"--overwrite",
|
||||
dest="clobber",
|
||||
action="store_true",
|
||||
help="Overwrite the merged model if --merged_model_name already exists",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
# ------------------------- GUI HERE -------------------------
|
||||
class FloatSlider(npyscreen.Slider):
|
||||
# this is supposed to adjust display precision, but doesn't
|
||||
def translate_value(self):
|
||||
stri = "%3.2f / %3.2f" % (self.value, self.out_of)
|
||||
l = (len(str(self.out_of))) * 2 + 4
|
||||
stri = stri.rjust(l)
|
||||
return stri
|
||||
|
||||
|
||||
class FloatTitleSlider(npyscreen.TitleText):
|
||||
_entry_type = FloatSlider
|
||||
|
||||
|
||||
class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
|
||||
interpolations = ["weighted_sum", "sigmoid", "inv_sigmoid", "add_difference"]
|
||||
|
||||
def __init__(self, parentApp, name):
|
||||
self.parentApp = parentApp
|
||||
super().__init__(parentApp, name)
|
||||
|
||||
@property
|
||||
def model_manager(self):
|
||||
return self.parentApp.model_manager
|
||||
|
||||
def afterEditing(self):
|
||||
self.parentApp.setNextForm(None)
|
||||
|
||||
def create(self):
|
||||
self.model_names = self.get_model_names()
|
||||
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText, name="Select up to three models to merge", value=""
|
||||
)
|
||||
self.models = self.add_widget_intelligent(
|
||||
npyscreen.TitleMultiSelect,
|
||||
name="Select two to three models to merge:",
|
||||
values=self.model_names,
|
||||
value=None,
|
||||
max_height=len(self.model_names) + 1,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.models.when_value_edited = self.models_changed
|
||||
self.merged_model_name = self.add_widget_intelligent(
|
||||
npyscreen.TitleText,
|
||||
name="Name for merged model:",
|
||||
value="",
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.force = self.add_widget_intelligent(
|
||||
npyscreen.Checkbox,
|
||||
name="Force merge of incompatible models",
|
||||
value=False,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.merge_method = self.add_widget_intelligent(
|
||||
npyscreen.TitleSelectOne,
|
||||
name="Merge Method:",
|
||||
values=self.interpolations,
|
||||
value=0,
|
||||
max_height=len(self.interpolations) + 1,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.alpha = self.add_widget_intelligent(
|
||||
FloatTitleSlider,
|
||||
name="Weight (alpha) to assign to second and third models:",
|
||||
out_of=1,
|
||||
step=0.05,
|
||||
lowest=0,
|
||||
value=0.5,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.models.editing = True
|
||||
|
||||
def models_changed(self):
|
||||
model_names = self.models.values
|
||||
selected_models = self.models.value
|
||||
if len(selected_models) > 3:
|
||||
npyscreen.notify_confirm(
|
||||
"Too many models selected for merging. Select two to three."
|
||||
)
|
||||
return
|
||||
elif len(selected_models) > 2:
|
||||
self.merge_method.values = ["add_difference"]
|
||||
self.merge_method.value = 0
|
||||
else:
|
||||
self.merge_method.values = self.interpolations
|
||||
self.merged_model_name.value = "+".join(
|
||||
[model_names[x] for x in selected_models]
|
||||
)
|
||||
|
||||
def on_ok(self):
|
||||
if self.validate_field_values() and self.check_for_overwrite():
|
||||
self.parentApp.setNextForm(None)
|
||||
self.editing = False
|
||||
self.parentApp.merge_arguments = self.marshall_arguments()
|
||||
npyscreen.notify("Starting the merge...")
|
||||
else:
|
||||
self.editing = True
|
||||
|
||||
def on_cancel(self):
|
||||
sys.exit(0)
|
||||
|
||||
def marshall_arguments(self) -> dict:
|
||||
models = [self.models.values[x] for x in self.models.value]
|
||||
args = dict(
|
||||
models=models,
|
||||
alpha=self.alpha.value,
|
||||
interp=self.interpolations[self.merge_method.value[0]],
|
||||
force=self.force.value,
|
||||
merged_model_name=self.merged_model_name.value,
|
||||
)
|
||||
return args
|
||||
|
||||
def check_for_overwrite(self) -> bool:
|
||||
model_out = self.merged_model_name.value
|
||||
if model_out not in self.model_names:
|
||||
return True
|
||||
else:
|
||||
return npyscreen.notify_yes_no(
|
||||
f"The chosen merged model destination, {model_out}, is already in use. Overwrite?"
|
||||
)
|
||||
|
||||
def validate_field_values(self) -> bool:
|
||||
bad_fields = []
|
||||
selected_models = self.models.value
|
||||
if len(selected_models) < 2 or len(selected_models) > 3:
|
||||
bad_fields.append("Please select two or three models to merge.")
|
||||
if len(bad_fields) > 0:
|
||||
message = "The following problems were detected and must be corrected:"
|
||||
for problem in bad_fields:
|
||||
message += f"\n* {problem}"
|
||||
npyscreen.notify_confirm(message)
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def get_model_names(self) -> List[str]:
|
||||
model_names = [
|
||||
name
|
||||
for name in self.model_manager.model_names()
|
||||
if self.model_manager.model_info(name).get("format") == "diffusers"
|
||||
]
|
||||
print(model_names)
|
||||
return sorted(model_names)
|
||||
|
||||
|
||||
class Mergeapp(npyscreen.NPSAppManaged):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
conf = OmegaConf.load(global_config_file())
|
||||
self.model_manager = ModelManager(
|
||||
conf, "cpu", "float16"
|
||||
) # precision doesn't really matter here
|
||||
|
||||
def onStart(self):
|
||||
npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
|
||||
self.main = self.addForm("MAIN", mergeModelsForm, name="Merge Models Settings")
|
||||
|
||||
|
||||
def run_gui(args: Namespace):
|
||||
mergeapp = Mergeapp()
|
||||
mergeapp.run()
|
||||
|
||||
args = mergeapp.merge_arguments
|
||||
merge_diffusion_models_and_commit(**args)
|
||||
print(f'>> Models merged into new model: "{args["merged_model_name"]}".')
|
||||
|
||||
|
||||
def run_cli(args: Namespace):
|
||||
assert args.alpha >= 0 and args.alpha <= 1.0, "alpha must be between 0 and 1"
|
||||
assert (
|
||||
len(args.models) >= 1 and len(args.models) <= 3
|
||||
), "provide 2 or 3 models to merge"
|
||||
|
||||
if not args.merged_model_name:
|
||||
args.merged_model_name = "+".join(args.models)
|
||||
print(
|
||||
f'>> No --merged_model_name provided. Defaulting to "{args.merged_model_name}"'
|
||||
)
|
||||
|
||||
model_manager = ModelManager(OmegaConf.load(global_config_file()))
|
||||
assert (
|
||||
args.clobber or args.merged_model_name not in model_manager.model_names()
|
||||
), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
|
||||
|
||||
merge_diffusion_models_and_commit(**vars(args))
|
||||
|
||||
|
||||
def main():
|
||||
args = _parse_args()
|
||||
global_set_root(args.root_dir)
|
||||
|
||||
cache_dir = str(global_cache_dir("diffusers"))
|
||||
os.environ[
|
||||
"HF_HOME"
|
||||
] = cache_dir # because not clear the merge pipeline is honoring cache_dir
|
||||
args.cache_dir = cache_dir
|
||||
|
||||
try:
|
||||
if args.front_end:
|
||||
run_gui(args)
|
||||
else:
|
||||
run_cli(args)
|
||||
print(f">> Conversion successful. New model is named {args.merged_model_name}")
|
||||
except Exception as e:
|
||||
print(f"** An error occurred while merging the pipelines: {str(e)}")
|
||||
sys.exit(-1)
|
||||
except KeyboardInterrupt:
|
||||
sys.exit(-1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
@ -14,7 +14,6 @@ import os
|
||||
import sys
|
||||
import textwrap
|
||||
import time
|
||||
import traceback
|
||||
import warnings
|
||||
import safetensors.torch
|
||||
from pathlib import Path
|
||||
@ -639,7 +638,7 @@ class ModelManager(object):
|
||||
and import.
|
||||
'''
|
||||
weights_directory = weights_directory or global_autoscan_dir()
|
||||
dest_directory = dest_directory or Path(global_models_dir(), 'optimized-ckpts')
|
||||
dest_directory = dest_directory or Path(global_models_dir(), Globals.converted_ckpts_dir)
|
||||
|
||||
print('>> Checking for unconverted .ckpt files in {weights_directory}')
|
||||
ckpt_files = dict()
|
||||
|
@ -1,7 +1,12 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import argparse
|
||||
import curses
|
||||
"""
|
||||
This is the frontend to "textual_inversion_training.py".
|
||||
|
||||
Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
|
||||
"""
|
||||
|
||||
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
@ -402,6 +407,7 @@ def do_front_end(args: Namespace):
|
||||
print("** DETAILS:")
|
||||
print(traceback.format_exc())
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
global_set_root(args.root_dir or Globals.root)
|
||||
@ -412,3 +418,7 @@ def main():
|
||||
do_textual_inversion_training(**vars(args))
|
||||
except AssertionError as e:
|
||||
print(str(e))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
@ -3,6 +3,10 @@
|
||||
# on January 2, 2023
|
||||
# and modified slightly by Lincoln Stein (@lstein) to work with InvokeAI
|
||||
|
||||
"""
|
||||
This is the backend to "textual_inversion.py"
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
@ -11,40 +15,41 @@ import random
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import datasets
|
||||
import diffusers
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
import datasets
|
||||
import diffusers
|
||||
import PIL
|
||||
import transformers
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDPMScheduler,
|
||||
StableDiffusionPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.utils import check_min_version
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from huggingface_hub import HfFolder, Repository, whoami
|
||||
|
||||
# invokeai stuff
|
||||
from ldm.invoke.args import (
|
||||
PagingArgumentParser,
|
||||
ArgFormatter
|
||||
)
|
||||
from ldm.invoke.globals import Globals, global_cache_dir
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
|
||||
from packaging import version
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import transforms
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
# invokeai stuff
|
||||
from ldm.invoke.args import ArgFormatter, PagingArgumentParser
|
||||
from ldm.invoke.globals import Globals, global_cache_dir
|
||||
|
||||
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
|
||||
PIL_INTERPOLATION = {
|
||||
"linear": PIL.Image.Resampling.BILINEAR,
|
||||
@ -71,34 +76,41 @@ check_min_version("0.10.0.dev0")
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def save_progress(text_encoder, placeholder_token_id, accelerator, placeholder_token, save_path):
|
||||
def save_progress(
|
||||
text_encoder, placeholder_token_id, accelerator, placeholder_token, save_path
|
||||
):
|
||||
logger.info("Saving embeddings")
|
||||
learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
|
||||
learned_embeds = (
|
||||
accelerator.unwrap_model(text_encoder)
|
||||
.get_input_embeddings()
|
||||
.weight[placeholder_token_id]
|
||||
)
|
||||
learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
|
||||
torch.save(learned_embeds_dict, save_path)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = PagingArgumentParser(
|
||||
description="Textual inversion training",
|
||||
formatter_class=ArgFormatter
|
||||
description="Textual inversion training", formatter_class=ArgFormatter
|
||||
)
|
||||
general_group = parser.add_argument_group('General')
|
||||
model_group = parser.add_argument_group('Models and Paths')
|
||||
image_group = parser.add_argument_group('Training Image Location and Options')
|
||||
trigger_group = parser.add_argument_group('Trigger Token')
|
||||
training_group = parser.add_argument_group('Training Parameters')
|
||||
checkpointing_group = parser.add_argument_group('Checkpointing and Resume')
|
||||
integration_group = parser.add_argument_group('Integration')
|
||||
general_group = parser.add_argument_group("General")
|
||||
model_group = parser.add_argument_group("Models and Paths")
|
||||
image_group = parser.add_argument_group("Training Image Location and Options")
|
||||
trigger_group = parser.add_argument_group("Trigger Token")
|
||||
training_group = parser.add_argument_group("Training Parameters")
|
||||
checkpointing_group = parser.add_argument_group("Checkpointing and Resume")
|
||||
integration_group = parser.add_argument_group("Integration")
|
||||
general_group.add_argument(
|
||||
'--front_end',
|
||||
'--gui',
|
||||
dest='front_end',
|
||||
"--front_end",
|
||||
"--gui",
|
||||
dest="front_end",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Activate the text-based graphical front end for collecting parameters. Other parameters will be ignored."
|
||||
help="Activate the text-based graphical front end for collecting parameters. Aside from --root_dir, other parameters will be ignored.",
|
||||
)
|
||||
general_group.add_argument(
|
||||
'--root_dir','--root',
|
||||
"--root_dir",
|
||||
"--root",
|
||||
type=Path,
|
||||
default=Globals.root,
|
||||
help="Path to the invokeai runtime directory",
|
||||
@ -115,13 +127,13 @@ def parse_args():
|
||||
general_group.add_argument(
|
||||
"--output_dir",
|
||||
type=Path,
|
||||
default=f'{Globals.root}/text-inversion-model',
|
||||
default=f"{Globals.root}/text-inversion-model",
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default='stable-diffusion-1.5',
|
||||
default="stable-diffusion-1.5",
|
||||
help="Name of the diffusers model to train against, as defined in configs/models.yaml.",
|
||||
)
|
||||
model_group.add_argument(
|
||||
@ -131,7 +143,7 @@ def parse_args():
|
||||
required=False,
|
||||
help="Revision of pretrained model identifier from huggingface.co/models.",
|
||||
)
|
||||
|
||||
|
||||
model_group.add_argument(
|
||||
"--tokenizer_name",
|
||||
type=str,
|
||||
@ -142,7 +154,7 @@ def parse_args():
|
||||
"--train_data_dir",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="A folder containing the training data."
|
||||
help="A folder containing the training data.",
|
||||
)
|
||||
image_group.add_argument(
|
||||
"--resolution",
|
||||
@ -154,28 +166,30 @@ def parse_args():
|
||||
),
|
||||
)
|
||||
image_group.add_argument(
|
||||
"--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
|
||||
"--center_crop",
|
||||
action="store_true",
|
||||
help="Whether to center crop images before resizing to resolution",
|
||||
)
|
||||
trigger_group.add_argument(
|
||||
"--placeholder_token",
|
||||
"--trigger_term",
|
||||
dest='placeholder_token',
|
||||
dest="placeholder_token",
|
||||
type=str,
|
||||
default=None,
|
||||
help="A token to use as a placeholder for the concept. This token will trigger the concept when included in the prompt as \"<trigger>\".",
|
||||
help='A token to use as a placeholder for the concept. This token will trigger the concept when included in the prompt as "<trigger>".',
|
||||
)
|
||||
trigger_group.add_argument(
|
||||
"--learnable_property",
|
||||
type=str,
|
||||
choices=['object','style'],
|
||||
choices=["object", "style"],
|
||||
default="object",
|
||||
help="Choose between 'object' and 'style'"
|
||||
help="Choose between 'object' and 'style'",
|
||||
)
|
||||
trigger_group.add_argument(
|
||||
"--initializer_token",
|
||||
type=str,
|
||||
default='*',
|
||||
help="A symbol to use as the initializer word."
|
||||
default="*",
|
||||
help="A symbol to use as the initializer word.",
|
||||
)
|
||||
checkpointing_group.add_argument(
|
||||
"--checkpointing_steps",
|
||||
@ -201,10 +215,20 @@ def parse_args():
|
||||
default=500,
|
||||
help="Save learned_embeds.bin every X updates steps.",
|
||||
)
|
||||
training_group.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.")
|
||||
training_group.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
||||
training_group.add_argument(
|
||||
"--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
|
||||
"--repeats",
|
||||
type=int,
|
||||
default=100,
|
||||
help="How many times to repeat the training data.",
|
||||
)
|
||||
training_group.add_argument(
|
||||
"--seed", type=int, default=None, help="A seed for reproducible training."
|
||||
)
|
||||
training_group.add_argument(
|
||||
"--train_batch_size",
|
||||
type=int,
|
||||
default=16,
|
||||
help="Batch size (per device) for the training dataloader.",
|
||||
)
|
||||
training_group.add_argument("--num_train_epochs", type=int, default=100)
|
||||
training_group.add_argument(
|
||||
@ -246,12 +270,32 @@ def parse_args():
|
||||
),
|
||||
)
|
||||
training_group.add_argument(
|
||||
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
|
||||
"--lr_warmup_steps",
|
||||
type=int,
|
||||
default=500,
|
||||
help="Number of steps for the warmup in the lr scheduler.",
|
||||
)
|
||||
training_group.add_argument(
|
||||
"--adam_beta1",
|
||||
type=float,
|
||||
default=0.9,
|
||||
help="The beta1 parameter for the Adam optimizer.",
|
||||
)
|
||||
training_group.add_argument(
|
||||
"--adam_beta2",
|
||||
type=float,
|
||||
default=0.999,
|
||||
help="The beta2 parameter for the Adam optimizer.",
|
||||
)
|
||||
training_group.add_argument(
|
||||
"--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use."
|
||||
)
|
||||
training_group.add_argument(
|
||||
"--adam_epsilon",
|
||||
type=float,
|
||||
default=1e-08,
|
||||
help="Epsilon value for the Adam optimizer",
|
||||
)
|
||||
training_group.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
|
||||
training_group.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
||||
training_group.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
|
||||
training_group.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
|
||||
training_group.add_argument(
|
||||
"--mixed_precision",
|
||||
type=str,
|
||||
@ -271,9 +315,16 @@ def parse_args():
|
||||
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
|
||||
),
|
||||
)
|
||||
training_group.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
training_group.add_argument(
|
||||
"--local_rank",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="For distributed training: local_rank",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
|
||||
"--enable_xformers_memory_efficient_attention",
|
||||
action="store_true",
|
||||
help="Whether or not to use xformers.",
|
||||
)
|
||||
|
||||
integration_group.add_argument(
|
||||
@ -297,8 +348,17 @@ def parse_args():
|
||||
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
||||
),
|
||||
)
|
||||
integration_group.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
||||
integration_group.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
|
||||
integration_group.add_argument(
|
||||
"--push_to_hub",
|
||||
action="store_true",
|
||||
help="Whether or not to push the model to the Hub.",
|
||||
)
|
||||
integration_group.add_argument(
|
||||
"--hub_token",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The token to use to push to the Model Hub.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
@ -378,7 +438,10 @@ class TextualInversionDataset(Dataset):
|
||||
self.center_crop = center_crop
|
||||
self.flip_p = flip_p
|
||||
|
||||
self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
|
||||
self.image_paths = [
|
||||
os.path.join(self.data_root, file_path)
|
||||
for file_path in os.listdir(self.data_root)
|
||||
]
|
||||
|
||||
self.num_images = len(self.image_paths)
|
||||
self._length = self.num_images
|
||||
@ -393,7 +456,11 @@ class TextualInversionDataset(Dataset):
|
||||
"lanczos": PIL_INTERPOLATION["lanczos"],
|
||||
}[interpolation]
|
||||
|
||||
self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
|
||||
self.templates = (
|
||||
imagenet_style_templates_small
|
||||
if learnable_property == "style"
|
||||
else imagenet_templates_small
|
||||
)
|
||||
self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
|
||||
|
||||
def __len__(self):
|
||||
@ -426,7 +493,9 @@ class TextualInversionDataset(Dataset):
|
||||
img.shape[0],
|
||||
img.shape[1],
|
||||
)
|
||||
img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
|
||||
img = img[
|
||||
(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2
|
||||
]
|
||||
|
||||
image = Image.fromarray(img)
|
||||
image = image.resize((self.size, self.size), resample=self.interpolation)
|
||||
@ -439,7 +508,9 @@ class TextualInversionDataset(Dataset):
|
||||
return example
|
||||
|
||||
|
||||
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
|
||||
def get_full_repo_name(
|
||||
model_id: str, organization: Optional[str] = None, token: Optional[str] = None
|
||||
):
|
||||
if token is None:
|
||||
token = HfFolder.get_token()
|
||||
if organization is None:
|
||||
@ -450,58 +521,60 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
|
||||
|
||||
|
||||
def do_textual_inversion_training(
|
||||
model:str,
|
||||
train_data_dir:Path,
|
||||
output_dir:Path,
|
||||
placeholder_token:str,
|
||||
initializer_token:str,
|
||||
save_steps:int=500,
|
||||
only_save_embeds:bool=False,
|
||||
revision:str=None,
|
||||
tokenizer_name:str=None,
|
||||
learnable_property:str='object',
|
||||
repeats:int=100,
|
||||
seed:int=None,
|
||||
resolution:int=512,
|
||||
center_crop:bool=False,
|
||||
train_batch_size:int=16,
|
||||
num_train_epochs:int=100,
|
||||
max_train_steps:int=5000,
|
||||
gradient_accumulation_steps:int=1,
|
||||
gradient_checkpointing:bool=False,
|
||||
learning_rate:float=1e-4,
|
||||
scale_lr:bool=True,
|
||||
lr_scheduler:str='constant',
|
||||
lr_warmup_steps:int=500,
|
||||
adam_beta1:float=0.9,
|
||||
adam_beta2:float=0.999,
|
||||
adam_weight_decay:float=1e-02,
|
||||
adam_epsilon:float=1e-08,
|
||||
push_to_hub:bool=False,
|
||||
hub_token:str=None,
|
||||
logging_dir:Path=Path('logs'),
|
||||
mixed_precision:str='fp16',
|
||||
allow_tf32:bool=False,
|
||||
report_to:str='tensorboard',
|
||||
local_rank:int=-1,
|
||||
checkpointing_steps:int=500,
|
||||
resume_from_checkpoint:Path=None,
|
||||
enable_xformers_memory_efficient_attention:bool=False,
|
||||
root_dir:Path=None,
|
||||
hub_model_id:str=None,
|
||||
**kwargs,
|
||||
model: str,
|
||||
train_data_dir: Path,
|
||||
output_dir: Path,
|
||||
placeholder_token: str,
|
||||
initializer_token: str,
|
||||
save_steps: int = 500,
|
||||
only_save_embeds: bool = False,
|
||||
revision: str = None,
|
||||
tokenizer_name: str = None,
|
||||
learnable_property: str = "object",
|
||||
repeats: int = 100,
|
||||
seed: int = None,
|
||||
resolution: int = 512,
|
||||
center_crop: bool = False,
|
||||
train_batch_size: int = 16,
|
||||
num_train_epochs: int = 100,
|
||||
max_train_steps: int = 5000,
|
||||
gradient_accumulation_steps: int = 1,
|
||||
gradient_checkpointing: bool = False,
|
||||
learning_rate: float = 1e-4,
|
||||
scale_lr: bool = True,
|
||||
lr_scheduler: str = "constant",
|
||||
lr_warmup_steps: int = 500,
|
||||
adam_beta1: float = 0.9,
|
||||
adam_beta2: float = 0.999,
|
||||
adam_weight_decay: float = 1e-02,
|
||||
adam_epsilon: float = 1e-08,
|
||||
push_to_hub: bool = False,
|
||||
hub_token: str = None,
|
||||
logging_dir: Path = Path("logs"),
|
||||
mixed_precision: str = "fp16",
|
||||
allow_tf32: bool = False,
|
||||
report_to: str = "tensorboard",
|
||||
local_rank: int = -1,
|
||||
checkpointing_steps: int = 500,
|
||||
resume_from_checkpoint: Path = None,
|
||||
enable_xformers_memory_efficient_attention: bool = False,
|
||||
root_dir: Path = None,
|
||||
hub_model_id: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
assert model, 'Please specify a base model with --model'
|
||||
assert train_data_dir, 'Please specify a directory containing the training images using --train_data_dir'
|
||||
assert placeholder_token, 'Please specify a trigger term using --placeholder_token'
|
||||
assert model, "Please specify a base model with --model"
|
||||
assert (
|
||||
train_data_dir
|
||||
), "Please specify a directory containing the training images using --train_data_dir"
|
||||
assert placeholder_token, "Please specify a trigger term using --placeholder_token"
|
||||
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||
if env_local_rank != -1 and env_local_rank != local_rank:
|
||||
local_rank = env_local_rank
|
||||
|
||||
# setting up things the way invokeai expects them
|
||||
if not os.path.isabs(output_dir):
|
||||
output_dir = os.path.join(Globals.root,output_dir)
|
||||
|
||||
output_dir = os.path.join(Globals.root, output_dir)
|
||||
|
||||
logging_dir = output_dir / logging_dir
|
||||
|
||||
accelerator = Accelerator(
|
||||
@ -548,28 +621,49 @@ def do_textual_inversion_training(
|
||||
elif output_dir is not None:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
models_conf = OmegaConf.load(os.path.join(Globals.root,'configs/models.yaml'))
|
||||
model_conf = models_conf.get(model,None)
|
||||
assert model_conf is not None,f'Unknown model: {model}'
|
||||
assert model_conf.get('format','diffusers')=='diffusers', "This script only works with models of type 'diffusers'"
|
||||
pretrained_model_name_or_path = model_conf.get('repo_id',None) or Path(model_conf.get('path'))
|
||||
assert pretrained_model_name_or_path, f"models.yaml error: neither 'repo_id' nor 'path' is defined for {model}"
|
||||
pipeline_args = dict(cache_dir=global_cache_dir('diffusers'))
|
||||
models_conf = OmegaConf.load(os.path.join(Globals.root, "configs/models.yaml"))
|
||||
model_conf = models_conf.get(model, None)
|
||||
assert model_conf is not None, f"Unknown model: {model}"
|
||||
assert (
|
||||
model_conf.get("format", "diffusers") == "diffusers"
|
||||
), "This script only works with models of type 'diffusers'"
|
||||
pretrained_model_name_or_path = model_conf.get("repo_id", None) or Path(
|
||||
model_conf.get("path")
|
||||
)
|
||||
assert (
|
||||
pretrained_model_name_or_path
|
||||
), f"models.yaml error: neither 'repo_id' nor 'path' is defined for {model}"
|
||||
pipeline_args = dict(cache_dir=global_cache_dir("diffusers"))
|
||||
|
||||
# Load tokenizer
|
||||
if tokenizer_name:
|
||||
tokenizer = CLIPTokenizer.from_pretrained(tokenizer_name,**pipeline_args)
|
||||
tokenizer = CLIPTokenizer.from_pretrained(tokenizer_name, **pipeline_args)
|
||||
else:
|
||||
tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer", **pipeline_args)
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
pretrained_model_name_or_path, subfolder="tokenizer", **pipeline_args
|
||||
)
|
||||
|
||||
# Load scheduler and models
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler", **pipeline_args)
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
pretrained_model_name_or_path, subfolder="text_encoder", revision=revision, **pipeline_args
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(
|
||||
pretrained_model_name_or_path, subfolder="scheduler", **pipeline_args
|
||||
)
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
pretrained_model_name_or_path,
|
||||
subfolder="text_encoder",
|
||||
revision=revision,
|
||||
**pipeline_args,
|
||||
)
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
pretrained_model_name_or_path,
|
||||
subfolder="vae",
|
||||
revision=revision,
|
||||
**pipeline_args,
|
||||
)
|
||||
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae", revision=revision, **pipeline_args)
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
pretrained_model_name_or_path, subfolder="unet", revision=revision, **pipeline_args
|
||||
pretrained_model_name_or_path,
|
||||
subfolder="unet",
|
||||
revision=revision,
|
||||
**pipeline_args,
|
||||
)
|
||||
|
||||
# Add the placeholder token in tokenizer
|
||||
@ -584,7 +678,9 @@ def do_textual_inversion_training(
|
||||
token_ids = tokenizer.encode(initializer_token, add_special_tokens=False)
|
||||
# Check if initializer_token is a single token or a sequence of tokens
|
||||
if len(token_ids) > 1:
|
||||
raise ValueError(f"The initializer token must be a single token. Provided initializer={initializer_token}. Token ids={token_ids}")
|
||||
raise ValueError(
|
||||
f"The initializer token must be a single token. Provided initializer={initializer_token}. Token ids={token_ids}"
|
||||
)
|
||||
|
||||
initializer_token_id = token_ids[0]
|
||||
placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)
|
||||
@ -615,7 +711,9 @@ def do_textual_inversion_training(
|
||||
if is_xformers_available():
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
else:
|
||||
raise ValueError("xformers is not available. Make sure it is installed correctly")
|
||||
raise ValueError(
|
||||
"xformers is not available. Make sure it is installed correctly"
|
||||
)
|
||||
|
||||
# Enable TF32 for faster training on Ampere GPUs,
|
||||
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
||||
@ -624,7 +722,10 @@ def do_textual_inversion_training(
|
||||
|
||||
if scale_lr:
|
||||
learning_rate = (
|
||||
learning_rate * gradient_accumulation_steps * train_batch_size * accelerator.num_processes
|
||||
learning_rate
|
||||
* gradient_accumulation_steps
|
||||
* train_batch_size
|
||||
* accelerator.num_processes
|
||||
)
|
||||
|
||||
# Initialize the optimizer
|
||||
@ -647,11 +748,15 @@ def do_textual_inversion_training(
|
||||
center_crop=center_crop,
|
||||
set="train",
|
||||
)
|
||||
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset, batch_size=train_batch_size, shuffle=True
|
||||
)
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
overrode_max_train_steps = False
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
|
||||
num_update_steps_per_epoch = math.ceil(
|
||||
len(train_dataloader) / gradient_accumulation_steps
|
||||
)
|
||||
if max_train_steps is None:
|
||||
max_train_steps = num_train_epochs * num_update_steps_per_epoch
|
||||
overrode_max_train_steps = True
|
||||
@ -681,7 +786,9 @@ def do_textual_inversion_training(
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
|
||||
num_update_steps_per_epoch = math.ceil(
|
||||
len(train_dataloader) / gradient_accumulation_steps
|
||||
)
|
||||
if overrode_max_train_steps:
|
||||
max_train_steps = num_train_epochs * num_update_steps_per_epoch
|
||||
# Afterwards we recalculate our number of training epochs
|
||||
@ -691,18 +798,22 @@ def do_textual_inversion_training(
|
||||
# The trackers initializes automatically on the main process.
|
||||
if accelerator.is_main_process:
|
||||
params = locals()
|
||||
for k in params: # init_trackers() doesn't like objects
|
||||
params[k] = str(params[k]) if isinstance(params[k],object) else params[k]
|
||||
for k in params: # init_trackers() doesn't like objects
|
||||
params[k] = str(params[k]) if isinstance(params[k], object) else params[k]
|
||||
accelerator.init_trackers("textual_inversion", config=params)
|
||||
|
||||
# Train!
|
||||
total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps
|
||||
total_batch_size = (
|
||||
train_batch_size * accelerator.num_processes * gradient_accumulation_steps
|
||||
)
|
||||
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(f" Num examples = {len(train_dataset)}")
|
||||
logger.info(f" Num Epochs = {num_train_epochs}")
|
||||
logger.info(f" Instantaneous batch size per device = {train_batch_size}")
|
||||
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
||||
logger.info(
|
||||
f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
|
||||
)
|
||||
logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
|
||||
logger.info(f" Total optimization steps = {max_train_steps}")
|
||||
global_step = 0
|
||||
@ -719,7 +830,7 @@ def do_textual_inversion_training(
|
||||
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
||||
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
||||
path = dirs[-1] if len(dirs) > 0 else None
|
||||
|
||||
|
||||
if path is None:
|
||||
accelerator.print(
|
||||
f"Checkpoint '{resume_from_checkpoint}' does not exist. Starting a new training run."
|
||||
@ -732,34 +843,57 @@ def do_textual_inversion_training(
|
||||
|
||||
resume_global_step = global_step * gradient_accumulation_steps
|
||||
first_epoch = global_step // num_update_steps_per_epoch
|
||||
resume_step = resume_global_step % (num_update_steps_per_epoch * gradient_accumulation_steps)
|
||||
|
||||
resume_step = resume_global_step % (
|
||||
num_update_steps_per_epoch * gradient_accumulation_steps
|
||||
)
|
||||
|
||||
# Only show the progress bar once on each machine.
|
||||
progress_bar = tqdm(range(global_step, max_train_steps), disable=not accelerator.is_local_main_process)
|
||||
progress_bar = tqdm(
|
||||
range(global_step, max_train_steps),
|
||||
disable=not accelerator.is_local_main_process,
|
||||
)
|
||||
progress_bar.set_description("Steps")
|
||||
|
||||
# keep original embeddings as reference
|
||||
orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone()
|
||||
orig_embeds_params = (
|
||||
accelerator.unwrap_model(text_encoder)
|
||||
.get_input_embeddings()
|
||||
.weight.data.clone()
|
||||
)
|
||||
|
||||
for epoch in range(first_epoch, num_train_epochs):
|
||||
text_encoder.train()
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
# Skip steps until we reach the resumed step
|
||||
if resume_step and resume_from_checkpoint and epoch == first_epoch and step < resume_step:
|
||||
if (
|
||||
resume_step
|
||||
and resume_from_checkpoint
|
||||
and epoch == first_epoch
|
||||
and step < resume_step
|
||||
):
|
||||
if step % gradient_accumulation_steps == 0:
|
||||
progress_bar.update(1)
|
||||
continue
|
||||
|
||||
with accelerator.accumulate(text_encoder):
|
||||
# Convert images to latent space
|
||||
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
|
||||
latents = (
|
||||
vae.encode(batch["pixel_values"].to(dtype=weight_dtype))
|
||||
.latent_dist.sample()
|
||||
.detach()
|
||||
)
|
||||
latents = latents * 0.18215
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
bsz = latents.shape[0]
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
|
||||
timesteps = torch.randint(
|
||||
0,
|
||||
noise_scheduler.config.num_train_timesteps,
|
||||
(bsz,),
|
||||
device=latents.device,
|
||||
)
|
||||
timesteps = timesteps.long()
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
@ -767,10 +901,14 @@ def do_textual_inversion_training(
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
# Get the text embedding for conditioning
|
||||
encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(dtype=weight_dtype)
|
||||
encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(
|
||||
dtype=weight_dtype
|
||||
)
|
||||
|
||||
# Predict the noise residual
|
||||
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
model_pred = unet(
|
||||
noisy_latents, timesteps, encoder_hidden_states
|
||||
).sample
|
||||
|
||||
# Get the target for loss depending on the prediction type
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
@ -778,7 +916,9 @@ def do_textual_inversion_training(
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||
else:
|
||||
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
||||
raise ValueError(
|
||||
f"Unknown prediction type {noise_scheduler.config.prediction_type}"
|
||||
)
|
||||
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
|
||||
@ -791,21 +931,35 @@ def do_textual_inversion_training(
|
||||
# Let's make sure we don't update any embedding weights besides the newly added token
|
||||
index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id
|
||||
with torch.no_grad():
|
||||
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
|
||||
accelerator.unwrap_model(
|
||||
text_encoder
|
||||
).get_input_embeddings().weight[
|
||||
index_no_updates
|
||||
] = orig_embeds_params[index_no_updates]
|
||||
] = orig_embeds_params[
|
||||
index_no_updates
|
||||
]
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
if global_step % save_steps == 0:
|
||||
save_path = os.path.join(output_dir, f"learned_embeds-steps-{global_step}.bin")
|
||||
save_progress(text_encoder, placeholder_token_id, accelerator, placeholder_token, save_path)
|
||||
save_path = os.path.join(
|
||||
output_dir, f"learned_embeds-steps-{global_step}.bin"
|
||||
)
|
||||
save_progress(
|
||||
text_encoder,
|
||||
placeholder_token_id,
|
||||
accelerator,
|
||||
placeholder_token,
|
||||
save_path,
|
||||
)
|
||||
|
||||
if global_step % checkpointing_steps == 0:
|
||||
if accelerator.is_main_process:
|
||||
save_path = os.path.join(output_dir, f"checkpoint-{global_step}")
|
||||
save_path = os.path.join(
|
||||
output_dir, f"checkpoint-{global_step}"
|
||||
)
|
||||
accelerator.save_state(save_path)
|
||||
logger.info(f"Saved state to {save_path}")
|
||||
|
||||
@ -820,7 +974,9 @@ def do_textual_inversion_training(
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
if push_to_hub and only_save_embeds:
|
||||
logger.warn("Enabling full model saving because --push_to_hub=True was specified.")
|
||||
logger.warn(
|
||||
"Enabling full model saving because --push_to_hub=True was specified."
|
||||
)
|
||||
save_full_model = True
|
||||
else:
|
||||
save_full_model = not only_save_embeds
|
||||
@ -836,9 +992,17 @@ def do_textual_inversion_training(
|
||||
pipeline.save_pretrained(output_dir)
|
||||
# Save the newly trained embeddings
|
||||
save_path = os.path.join(output_dir, "learned_embeds.bin")
|
||||
save_progress(text_encoder, placeholder_token_id, accelerator, placeholder_token, save_path)
|
||||
save_progress(
|
||||
text_encoder,
|
||||
placeholder_token_id,
|
||||
accelerator,
|
||||
placeholder_token,
|
||||
save_path,
|
||||
)
|
||||
|
||||
if push_to_hub:
|
||||
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
|
||||
repo.push_to_hub(
|
||||
commit_message="End of training", blocking=False, auto_lfs_prune=True
|
||||
)
|
||||
|
||||
accelerator.end_training()
|
||||
|
@ -92,14 +92,9 @@ test = ["pytest>6.0.0", "pytest-cov"]
|
||||
|
||||
[project.scripts]
|
||||
"configure_invokeai" = "ldm.invoke.configure_invokeai:main"
|
||||
"dream" = "ldm.invoke:CLI.main"
|
||||
"invoke" = "ldm.invoke:CLI.main"
|
||||
"legacy_api" = "scripts:legacy_api.main"
|
||||
"load_models" = "scripts:configure_invokeai.main"
|
||||
"merge_embeddings" = "scripts:merge_embeddings.main"
|
||||
"preload_models" = "ldm.invoke.configure_invokeai:main"
|
||||
"invoke" = "ldm.invoke.CLI:main"
|
||||
"textual_inversion" = "ldm.invoke.textual_inversion:main"
|
||||
"merge_models" = "ldm.invoke.merge_models:main"
|
||||
"merge_models" = "ldm.invoke.merge_diffusers:main" # note name munging
|
||||
|
||||
[project.urls]
|
||||
"Homepage" = "https://invoke-ai.github.io/InvokeAI/"
|
||||
|
@ -1,91 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from ldm.invoke.globals import (Globals, global_cache_dir, global_config_file,
|
||||
global_set_root)
|
||||
from ldm.invoke.model_manager import ModelManager
|
||||
|
||||
parser = argparse.ArgumentParser(description="InvokeAI textual inversion training")
|
||||
parser.add_argument(
|
||||
"--root_dir",
|
||||
type=Path,
|
||||
default=Globals.root,
|
||||
help="Path to the invokeai runtime directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--models",
|
||||
required=True,
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="Two to three model names to be merged",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--merged_model_name",
|
||||
"--destination",
|
||||
dest="merged_model_name",
|
||||
type=str,
|
||||
help="Name of the output model. If not specified, will be the concatenation of the input model names.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--alpha",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="The interpolation parameter, ranging from 0 to 1. It affects the ratio in which the checkpoints are merged. Higher values give more weight to the 2d and 3d models",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--interpolation",
|
||||
dest="interp",
|
||||
type=str,
|
||||
choices=["weighted_sum", "sigmoid", "inv_sigmoid", "add_difference"],
|
||||
default="weighted_sum",
|
||||
help='Interpolation method to use. If three models are present, only "add_difference" will work.',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--force",
|
||||
action="store_true",
|
||||
help="Try to merge models even if they are incompatible with each other",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--clobber",
|
||||
"--overwrite",
|
||||
dest='clobber',
|
||||
action="store_true",
|
||||
help="Overwrite the merged model if --merged_model_name already exists",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
global_set_root(args.root_dir)
|
||||
|
||||
assert args.alpha >= 0 and args.alpha <= 1.0, "alpha must be between 0 and 1"
|
||||
assert len(args.models) >= 1 and len(args.models) <= 3, "provide 2 or 3 models to merge"
|
||||
|
||||
if not args.merged_model_name:
|
||||
args.merged_model_name = "+".join(args.models)
|
||||
print(
|
||||
f'>> No --merged_model_name provided. Defaulting to "{args.merged_model_name}"'
|
||||
)
|
||||
|
||||
model_manager = ModelManager(OmegaConf.load(global_config_file()))
|
||||
assert (args.clobber or args.merged_model_name not in model_manager.model_names()), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
|
||||
|
||||
# It seems that the merge pipeline is not honoring cache_dir, so we set the
|
||||
# HF_HOME environment variable here *before* we load diffusers.
|
||||
cache_dir = str(global_cache_dir("diffusers"))
|
||||
os.environ["HF_HOME"] = cache_dir
|
||||
from ldm.invoke.merge_diffusers import merge_diffusion_models
|
||||
|
||||
try:
|
||||
merge_diffusion_models(**vars(args))
|
||||
print(f'>> Models merged into new model: "{args.merged_model_name}".')
|
||||
except Exception as e:
|
||||
print(f"** An error occurred while merging the pipelines: {str(e)}")
|
||||
print("** DETAILS:")
|
||||
print(traceback.format_exc())
|
||||
sys.exit(-1)
|
@ -1,217 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import npyscreen
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import argparse
|
||||
from ldm.invoke.globals import Globals, global_set_root, global_cache_dir, global_config_file
|
||||
from ldm.invoke.model_manager import ModelManager
|
||||
from omegaconf import OmegaConf
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
class FloatSlider(npyscreen.Slider):
|
||||
# this is supposed to adjust display precision, but doesn't
|
||||
def translate_value(self):
|
||||
stri = "%3.2f / %3.2f" %(self.value, self.out_of)
|
||||
l = (len(str(self.out_of)))*2+4
|
||||
stri = stri.rjust(l)
|
||||
return stri
|
||||
|
||||
class FloatTitleSlider(npyscreen.TitleText):
|
||||
_entry_type = FloatSlider
|
||||
|
||||
class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
|
||||
interpolations = ['weighted_sum',
|
||||
'sigmoid',
|
||||
'inv_sigmoid',
|
||||
'add_difference']
|
||||
|
||||
def __init__(self, parentApp, name):
|
||||
self.parentApp = parentApp
|
||||
super().__init__(parentApp, name)
|
||||
|
||||
@property
|
||||
def model_manager(self):
|
||||
return self.parentApp.model_manager
|
||||
|
||||
def afterEditing(self):
|
||||
self.parentApp.setNextForm(None)
|
||||
|
||||
def create(self):
|
||||
self.model_names = self.get_model_names()
|
||||
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
name="Select up to three models to merge",
|
||||
value=''
|
||||
)
|
||||
self.model1 = self.add_widget_intelligent(
|
||||
npyscreen.TitleSelectOne,
|
||||
name='First Model:',
|
||||
values=self.model_names,
|
||||
value=0,
|
||||
max_height=len(self.model_names)+1
|
||||
)
|
||||
self.model2 = self.add_widget_intelligent(
|
||||
npyscreen.TitleSelectOne,
|
||||
name='Second Model:',
|
||||
values=self.model_names,
|
||||
value=1,
|
||||
max_height=len(self.model_names)+1
|
||||
)
|
||||
models_plus_none = self.model_names.copy()
|
||||
models_plus_none.insert(0,'None')
|
||||
self.model3 = self.add_widget_intelligent(
|
||||
npyscreen.TitleSelectOne,
|
||||
name='Third Model:',
|
||||
values=models_plus_none,
|
||||
value=0,
|
||||
max_height=len(self.model_names)+1,
|
||||
)
|
||||
|
||||
for m in [self.model1,self.model2,self.model3]:
|
||||
m.when_value_edited = self.models_changed
|
||||
|
||||
self.merge_method = self.add_widget_intelligent(
|
||||
npyscreen.TitleSelectOne,
|
||||
name='Merge Method:',
|
||||
values=self.interpolations,
|
||||
value=0,
|
||||
max_height=len(self.interpolations),
|
||||
)
|
||||
self.alpha = self.add_widget_intelligent(
|
||||
FloatTitleSlider,
|
||||
name='Weight (alpha) to assign to second and third models:',
|
||||
out_of=1,
|
||||
step=0.05,
|
||||
lowest=0,
|
||||
value=0.5,
|
||||
)
|
||||
self.force = self.add_widget_intelligent(
|
||||
npyscreen.Checkbox,
|
||||
name='Force merge of incompatible models',
|
||||
value=False,
|
||||
)
|
||||
self.merged_model_name = self.add_widget_intelligent(
|
||||
npyscreen.TitleText,
|
||||
name='Name for merged model',
|
||||
value='',
|
||||
)
|
||||
|
||||
def models_changed(self):
|
||||
models = self.model1.values
|
||||
selected_model1 = self.model1.value[0]
|
||||
selected_model2 = self.model2.value[0]
|
||||
selected_model3 = self.model3.value[0]
|
||||
merged_model_name = f'{models[selected_model1]}+{models[selected_model2]}'
|
||||
self.merged_model_name.value = merged_model_name
|
||||
|
||||
if selected_model3 > 0:
|
||||
self.merge_method.values=['add_difference'],
|
||||
self.merged_model_name.value += f'+{models[selected_model3]}'
|
||||
else:
|
||||
self.merge_method.values=self.interpolations
|
||||
self.merge_method.value=0
|
||||
|
||||
def on_ok(self):
|
||||
if self.validate_field_values() and self.check_for_overwrite():
|
||||
self.parentApp.setNextForm(None)
|
||||
self.editing = False
|
||||
self.parentApp.merge_arguments = self.marshall_arguments()
|
||||
npyscreen.notify('Starting the merge...')
|
||||
import ldm.invoke.merge_diffusers # this keeps the message up while diffusers loads
|
||||
else:
|
||||
self.editing = True
|
||||
|
||||
def on_cancel(self):
|
||||
sys.exit(0)
|
||||
|
||||
def marshall_arguments(self)->dict:
|
||||
model_names = self.model_names
|
||||
models = [
|
||||
model_names[self.model1.value[0]],
|
||||
model_names[self.model2.value[0]],
|
||||
]
|
||||
if self.model3.value[0] > 0:
|
||||
models.append(model_names[self.model3.value[0]-1])
|
||||
|
||||
args = dict(
|
||||
models=models,
|
||||
alpha = self.alpha.value,
|
||||
interp = self.interpolations[self.merge_method.value[0]],
|
||||
force = self.force.value,
|
||||
merged_model_name = self.merged_model_name.value,
|
||||
)
|
||||
return args
|
||||
|
||||
def check_for_overwrite(self)->bool:
|
||||
model_out = self.merged_model_name.value
|
||||
if model_out not in self.model_names:
|
||||
return True
|
||||
else:
|
||||
return npyscreen.notify_yes_no(f'The chosen merged model destination, {model_out}, is already in use. Overwrite?')
|
||||
|
||||
def validate_field_values(self)->bool:
|
||||
bad_fields = []
|
||||
model_names = self.model_names
|
||||
selected_models = set((model_names[self.model1.value[0]],model_names[self.model2.value[0]]))
|
||||
if self.model3.value[0] > 0:
|
||||
selected_models.add(model_names[self.model3.value[0]-1])
|
||||
if len(selected_models) < 2:
|
||||
bad_fields.append(f'Please select two or three DIFFERENT models to compare. You selected {selected_models}')
|
||||
if len(bad_fields) > 0:
|
||||
message = 'The following problems were detected and must be corrected:'
|
||||
for problem in bad_fields:
|
||||
message += f'\n* {problem}'
|
||||
npyscreen.notify_confirm(message)
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def get_model_names(self)->List[str]:
|
||||
model_names = [name for name in self.model_manager.model_names() if self.model_manager.model_info(name).get('format') == 'diffusers']
|
||||
print(model_names)
|
||||
return sorted(model_names)
|
||||
|
||||
class Mergeapp(npyscreen.NPSAppManaged):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
conf = OmegaConf.load(global_config_file())
|
||||
self.model_manager = ModelManager(conf,'cpu','float16') # precision doesn't really matter here
|
||||
|
||||
def onStart(self):
|
||||
npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
|
||||
self.main = self.addForm('MAIN', mergeModelsForm, name='Merge Models Settings')
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='InvokeAI textual inversion training')
|
||||
parser.add_argument(
|
||||
'--root_dir','--root-dir',
|
||||
type=Path,
|
||||
default=Globals.root,
|
||||
help='Path to the invokeai runtime directory',
|
||||
)
|
||||
args = parser.parse_args()
|
||||
global_set_root(args.root_dir)
|
||||
|
||||
cache_dir = str(global_cache_dir('diffusers')) # because not clear the merge pipeline is honoring cache_dir
|
||||
os.environ['HF_HOME'] = cache_dir
|
||||
|
||||
mergeapp = Mergeapp()
|
||||
mergeapp.run()
|
||||
|
||||
args = mergeapp.merge_arguments
|
||||
args.update(cache_dir = cache_dir)
|
||||
from ldm.invoke.merge_diffusers import merge_diffusion_models
|
||||
|
||||
try:
|
||||
merge_diffusion_models(**args)
|
||||
print(f'>> Models merged into new model: "{args["merged_model_name"]}".')
|
||||
except Exception as e:
|
||||
print(f'** An error occurred while merging the pipelines: {str(e)}')
|
||||
print('** DETAILS:')
|
||||
print(traceback.format_exc())
|
||||
sys.exit(-1)
|
Loading…
Reference in New Issue
Block a user