all files migrated; tweaks needed
1237
invokeai/frontend/CLI/CLI.py
Normal file
4
invokeai/frontend/CLI/__init__.py
Normal file
@@ -0,0 +1,4 @@
'''
Initialization file for invokeai.frontend.CLI
'''
from .CLI import main as invokeai_command_line_interface
455
invokeai/frontend/CLI/readline.py
Normal file
@@ -0,0 +1,455 @@
"""
Readline helper functions for invoke.py.
You may import the global singleton `completer` to get access to the
completer object itself. This is useful when you want to autocomplete
seeds:

    from invokeai.frontend.CLI.readline import completer
    completer.add_seed(18247566)
    completer.add_seed(9281839)
"""
import os
import re
import atexit

from ...backend.args import Args
from ...backend.globals import Globals
from ...backend.stable_diffusion import HuggingFaceConceptsLibrary

# ---------------readline utilities---------------------
try:
    import readline

    readline_available = True
except (ImportError, ModuleNotFoundError) as e:
    print(f'** An error occurred when loading the readline module: {str(e)}')
    readline_available = False

IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG', '.gif', '.GIF')
WEIGHT_EXTENSIONS = ('.ckpt', '.vae', '.safetensors')
TEXT_EXTENSIONS = ('.txt', '.TXT')
CONFIG_EXTENSIONS = ('.yaml', '.yml')
COMMANDS = (
    '--steps', '-s',
    '--seed', '-S',
    '--iterations', '-n',
    '--width', '-W', '--height', '-H',
    '--cfg_scale', '-C',
    '--threshold',
    '--perlin',
    '--grid', '-g',
    '--individual', '-i',
    '--save_intermediates',
    '--init_img', '-I',
    '--init_mask', '-M',
    '--init_color',
    '--strength', '-f',
    '--variants', '-v',
    '--outdir', '-o',
    '--sampler', '-A', '-m',
    '--embedding_path',
    '--device',
    '--facetool', '-ft',
    '--facetool_strength', '-G',
    '--codeformer_fidelity', '-cf',
    '--upscale', '-U',
    '-save_orig', '--save_original',
    '--log_tokenization', '-t',
    '--hires_fix',
    '--inpaint_replace', '-r',
    '--png_compression', '-z',
    '--text_mask', '-tm',
    '--h_symmetry_time_pct',
    '--v_symmetry_time_pct',
    '!fix', '!fetch', '!replay', '!history', '!search', '!clear',
    '!models', '!switch', '!import_model', '!optimize_model', '!convert_model', '!edit_model', '!del_model',
    '!mask', '!triggers',
)
MODEL_COMMANDS = (
    '!switch',
    '!edit_model',
    '!del_model',
)
CKPT_MODEL_COMMANDS = (
    '!optimize_model',
)
WEIGHT_COMMANDS = (
    '!import_model',
    '!convert_model',
)
IMG_PATH_COMMANDS = (
    r'--outdir[=\s]',
)
TEXT_PATH_COMMANDS = (
    '!replay',
)
IMG_FILE_COMMANDS = (
    '!fix',
    '!fetch',
    '!mask',
    r'--init_img[=\s]', '-I',
    r'--init_mask[=\s]', '-M',
    r'--init_color[=\s]',
    r'--embedding_path[=\s]',
)

path_regexp = '(' + '|'.join(IMG_PATH_COMMANDS + IMG_FILE_COMMANDS) + r')\s*\S*$'
weight_regexp = '(' + '|'.join(WEIGHT_COMMANDS) + r')\s*\S*$'
text_regexp = '(' + '|'.join(TEXT_PATH_COMMANDS) + r')\s*\S*$'


class Completer(object):
    def __init__(self, options, models={}):
        self.options = sorted(options)
        self.models = models
        self.seeds = set()
        self.matches = list()
        self.default_dir = None
        self.linebuffer = None
        self.auto_history_active = True
        self.extensions = None
        self.concepts = None
        self.embedding_terms = set()
        return

    def complete(self, text, state):
        '''
        Completes invoke command line.
        BUG: it doesn't correctly complete files that have spaces in the name.
        '''
        buffer = readline.get_line_buffer()

        if state == 0:

            # extensions defined, so go directly into path completion mode
            if self.extensions is not None:
                self.matches = self._path_completions(text, state, self.extensions)

            # looking for an image file
            elif re.search(path_regexp, buffer):
                do_shortcut = re.search('^' + '|'.join(IMG_FILE_COMMANDS), buffer)
                self.matches = self._path_completions(text, state, IMG_EXTENSIONS, shortcut_ok=do_shortcut)

            # looking for a seed
            elif re.search(r'(-S\s*|--seed[=\s])\d*$', buffer):
                self.matches = self._seed_completions(text, state)

            # looking for an embedding concept
            elif re.search(r'<[\w-]*$', buffer):
                self.matches = self._concept_completions(text, state)

            # looking for a model
            elif re.match('^' + '|'.join(MODEL_COMMANDS), buffer):
                self.matches = self._model_completions(text, state)

            # looking for a ckpt model
            elif re.match('^' + '|'.join(CKPT_MODEL_COMMANDS), buffer):
                self.matches = self._model_completions(text, state, ckpt_only=True)

            elif re.search(weight_regexp, buffer):
                self.matches = self._path_completions(
                    text,
                    state,
                    WEIGHT_EXTENSIONS,
                    default_dir=Globals.root,
                )

            elif re.search(text_regexp, buffer):
                self.matches = self._path_completions(text, state, TEXT_EXTENSIONS)

            # This is the first time for this text, so build a match list.
            elif text:
                self.matches = [
                    s for s in self.options if s and s.startswith(text)
                ]
            else:
                self.matches = self.options[:]

        # Return the state'th item from the match list,
        # if we have that many.
        try:
            response = self.matches[state]
        except IndexError:
            response = None
        return response

    def complete_extensions(self, extensions: list):
        '''
        If called with a list of extensions, will force completer
        to do file path completions.
        '''
        self.extensions = extensions

    def add_history(self, line):
        '''
        Pass thru to readline
        '''
        if not self.auto_history_active:
            readline.add_history(line)

    def clear_history(self):
        '''
        Pass clear_history() thru to readline
        '''
        readline.clear_history()

    def search_history(self, match: str):
        '''
        Like show_history() but only shows items that
        contain the match string.
        '''
        self.show_history(match)

    def remove_history_item(self, pos):
        readline.remove_history_item(pos)

    def add_seed(self, seed):
        '''
        Add a seed to the autocomplete list for display when -S is autocompleted.
        '''
        if seed is not None:
            self.seeds.add(str(seed))

    def set_default_dir(self, path):
        self.default_dir = path

    def set_options(self, options):
        self.options = options

    def get_line(self, index):
        try:
            line = self.get_history_item(index)
        except IndexError:
            return None
        return line

    def get_current_history_length(self):
        return readline.get_current_history_length()

    def get_history_item(self, index):
        return readline.get_history_item(index)

    def show_history(self, match=None):
        '''
        Print the session history using the pydoc pager
        '''
        import pydoc

        lines = list()
        h_len = self.get_current_history_length()
        if h_len < 1:
            print('<empty history>')
            return

        for i in range(0, h_len):
            line = self.get_history_item(i + 1)
            if match and match not in line:
                continue
            lines.append(f'[{i + 1}] {line}')
        pydoc.pager('\n'.join(lines))

    def set_line(self, line) -> None:
        '''
        Set the default string displayed in the next line of input.
        '''
        self.linebuffer = line
        readline.redisplay()

    def update_models(self, models: dict) -> None:
        '''
        update our list of models
        '''
        self.models = models

    def _seed_completions(self, text, state):
        m = re.search(r'(-S\s?|--seed[=\s]?)(\d*)', text)
        if m:
            switch = m.groups()[0]
            partial = m.groups()[1]
        else:
            switch = ''
            partial = text

        matches = list()
        for s in self.seeds:
            if s.startswith(partial):
                matches.append(switch + s)
        matches.sort()
        return matches

    def add_embedding_terms(self, terms: list[str]):
        self.embedding_terms = set(terms)
        if self.concepts:
            self.embedding_terms.update(set(self.concepts.list_concepts()))

    def _concept_completions(self, text, state):
        if self.concepts is None:
            # cache Concepts() instance so we can check for updates in concepts_list during runtime.
            self.concepts = HuggingFaceConceptsLibrary()
            self.embedding_terms.update(set(self.concepts.list_concepts()))
        else:
            self.embedding_terms.update(set(self.concepts.list_concepts()))

        partial = text[1:]  # this removes the leading '<'
        if len(partial) == 0:
            return list(self.embedding_terms)  # whole dump - think if user wants this!

        matches = list()
        for concept in self.embedding_terms:
            if concept.startswith(partial):
                matches.append(f'<{concept}>')
        matches.sort()
        return matches

    def _model_completions(self, text, state, ckpt_only=False):
        m = re.search(r'(!switch\s+)(\w*)', text)
        if m:
            switch = m.groups()[0]
            partial = m.groups()[1]
        else:
            switch = ''
            partial = text
        matches = list()
        for s in self.models:
            format = self.models[s]['format']
            if format == 'vae':
                continue
            if ckpt_only and format != 'ckpt':
                continue
            if s.startswith(partial):
                matches.append(switch + s)
        matches.sort()
        return matches

    def _pre_input_hook(self):
        if self.linebuffer:
            readline.insert_text(self.linebuffer)
            readline.redisplay()
            self.linebuffer = None

    def _path_completions(self, text, state, extensions, shortcut_ok=True, default_dir: str = ''):
        # separate the switch from the partial path
        match = re.search(r'^(-\w|--\w+=?)(.*)', text)
        if match is None:
            switch = None
            partial_path = text
        else:
            switch, partial_path = match.groups()

        partial_path = partial_path.lstrip()

        matches = list()
        path = os.path.expanduser(partial_path)

        if os.path.isdir(path):
            dir = path
        elif os.path.dirname(path) != '':
            dir = os.path.dirname(path)
        else:
            dir = default_dir if os.path.exists(default_dir) else ''
            path = os.path.join(dir, path)

        dir_list = os.listdir(dir or '.')
        if shortcut_ok and self.default_dir and os.path.exists(self.default_dir) and dir == '':
            dir_list += os.listdir(self.default_dir)

        for node in dir_list:
            if node.startswith('.') and len(node) > 1:
                continue
            full_path = os.path.join(dir, node)

            if not (node.endswith(extensions) or os.path.isdir(full_path)):
                continue

            if path and not full_path.startswith(path):
                continue

            if switch is None:
                match_path = os.path.join(dir, node)
                matches.append(match_path + '/' if os.path.isdir(full_path) else match_path)
            elif os.path.isdir(full_path):
                matches.append(
                    switch + os.path.join(os.path.dirname(full_path), node) + '/'
                )
            elif node.endswith(extensions):
                matches.append(
                    switch + os.path.join(os.path.dirname(full_path), node)
                )

        return matches


class DummyCompleter(Completer):
    def __init__(self, options):
        super().__init__(options)
        self.history = list()

    def add_history(self, line):
        self.history.append(line)

    def clear_history(self):
        self.history = list()

    def get_current_history_length(self):
        return len(self.history)

    def get_history_item(self, index):
        return self.history[index - 1]

    def remove_history_item(self, index):
        return self.history.pop(index - 1)

    def set_line(self, line):
        print(f'# {line}')


def generic_completer(commands: list) -> Completer:
    if readline_available:
        completer = Completer(commands, [])
        readline.set_completer(completer.complete)
        readline.set_pre_input_hook(completer._pre_input_hook)
        readline.set_completer_delims(' ')
        readline.parse_and_bind('tab: complete')
        readline.parse_and_bind('set print-completions-horizontally off')
        readline.parse_and_bind('set page-completions on')
        readline.parse_and_bind('set skip-completed-text on')
        readline.parse_and_bind('set show-all-if-ambiguous on')
    else:
        completer = DummyCompleter(commands)
    return completer


def get_completer(opt: Args, models=[]) -> Completer:
    if readline_available:
        completer = Completer(COMMANDS, models)

        readline.set_completer(completer.complete)
        # pyreadline3 does not have a set_auto_history() method
        try:
            readline.set_auto_history(False)
            completer.auto_history_active = False
        except AttributeError:
            completer.auto_history_active = True
        readline.set_pre_input_hook(completer._pre_input_hook)
        readline.set_completer_delims(' ')
        readline.parse_and_bind('tab: complete')
        readline.parse_and_bind('set print-completions-horizontally off')
        readline.parse_and_bind('set page-completions on')
        readline.parse_and_bind('set skip-completed-text on')
        readline.parse_and_bind('set show-all-if-ambiguous on')

        outdir = os.path.expanduser(opt.outdir)
        if os.path.isabs(outdir):
            histfile = os.path.join(outdir, '.invoke_history')
        else:
            histfile = os.path.join(Globals.root, outdir, '.invoke_history')
        try:
            readline.read_history_file(histfile)
            readline.set_history_length(1000)
        except FileNotFoundError:
            pass
        except OSError:  # file likely corrupted
            newname = f'{histfile}.old'
            print(f'## Your history file {histfile} couldn\'t be loaded and may be corrupted. Renaming it to {newname}')
            os.replace(histfile, newname)
        atexit.register(readline.write_history_file, histfile)

    else:
        completer = DummyCompleter(COMMANDS)
    return completer
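For orientation, here is a minimal sketch of how a command loop might wire the completer up. The `opt` stand-in and the model dict are illustrative assumptions, not values from this PR; the real caller passes the CLI's `Args` object and the model manager's model dict, but `get_completer()` only reads `opt.outdir` and each model's `format` field:

    from argparse import Namespace
    from invokeai.frontend.CLI.readline import get_completer

    opt = Namespace(outdir='outputs')  # illustrative stand-in for the real Args object
    models = {'stable-diffusion-1.5': {'format': 'diffusers'}}  # shape checked by _model_completions()

    completer = get_completer(opt, models)
    completer.add_seed(18247566)        # '-S 18<tab>' now completes to this seed
    while True:
        command = input('invoke> ')     # tab completion and history are active here
        completer.add_history(command)  # no-op when readline's auto-history is on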
3
invokeai/frontend/__init__.py
Normal file
@@ -0,0 +1,3 @@
'''
Initialization file for invokeai.frontend
'''
7
invokeai/frontend/config/__init__.py
Normal file
@@ -0,0 +1,7 @@
'''
Initialization file for invokeai.frontend.config
'''
from .model_install import main as invokeai_model_install
from .invokeai_configure import main as invokeai_configure
from .invokeai_update import main as invokeai_update
4
invokeai/frontend/config/invokeai_configure.py
Normal file
@@ -0,0 +1,4 @@
'''
Wrapper for invokeai.backend.config.invokeai_configure
'''
from ...backend.config.invokeai_configure import main
88
invokeai/frontend/config/invokeai_update.py
Normal file
@@ -0,0 +1,88 @@
'''
Minimalist updater script. Prompts user for the tag or branch to update to and runs
pip install <path_to_git_source>.
'''
import os
import platform
import requests
from rich import box, print
from rich.console import Console, Group, group
from rich.panel import Panel
from rich.prompt import Prompt
from rich.style import Style
from rich.syntax import Syntax
from rich.text import Text

from invokeai.version import __version__

INVOKE_AI_SRC = "https://github.com/invoke-ai/InvokeAI/archive"
INVOKE_AI_REL = "https://api.github.com/repos/invoke-ai/InvokeAI/releases"

OS = platform.uname().system
ARCH = platform.uname().machine

if OS == "Windows":
    # Windows terminals look better without a background colour
    console = Console(style=Style(color="grey74"))
else:
    console = Console(style=Style(color="grey74", bgcolor="grey19"))


def get_versions() -> list:
    # the GitHub releases endpoint returns a JSON list of release objects
    return requests.get(url=INVOKE_AI_REL).json()


def welcome(versions: list):

    @group()
    def text():
        yield f'InvokeAI Version: [bold yellow]{__version__}'
        yield ''
        yield 'This script will update InvokeAI to the latest release, or to a development version of your choice.'
        yield ''
        yield '[bold yellow]Options:'
        yield f'''[1] Update to the latest official release ([italic]{versions[0]['tag_name']}[/italic])
[2] Update to the bleeding-edge development version ([italic]main[/italic])
[3] Manually enter the tag or branch name you wish to update to'''

    console.rule()
    print(
        Panel(
            title="[bold wheat1]InvokeAI Updater",
            renderable=text(),
            box=box.DOUBLE,
            expand=True,
            padding=(1, 2),
            style=Style(bgcolor="grey23", color="orange1"),
            subtitle=f"[bold grey39]{OS}-{ARCH}",
        )
    )
    console.line()


def main():
    versions = get_versions()
    welcome(versions)

    tag = None
    choice = Prompt.ask('Choice:', choices=['1', '2', '3'], default='1')

    if choice == '1':
        tag = versions[0]['tag_name']
    elif choice == '2':
        tag = 'main'
    elif choice == '3':
        tag = Prompt.ask('Enter an InvokeAI tag or branch name')

    print(f':crossed_fingers: Upgrading to [yellow]{tag}[/yellow]')
    cmd = f'pip install {INVOKE_AI_SRC}/{tag}.zip --use-pep517'
    print('')
    print('')
    if os.system(cmd) == 0:
        print(':heavy_check_mark: Upgrade successful')
    else:
        print(':exclamation: [bold red]Upgrade failed[/bold red]')


if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        pass
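To make the upgrade step concrete: the script only assembles an archive URL from the chosen tag and hands it to pip. With a hypothetical latest release of v2.3.0, choice 1 reduces to:

    # 'v2.3.0' is a hypothetical tag; the real value comes from the GitHub releases API
    tag = 'v2.3.0'
    cmd = f'pip install {INVOKE_AI_SRC}/{tag}.zip --use-pep517'
    # i.e. pip install https://github.com/invoke-ai/InvokeAI/archive/v2.3.0.zip --use-pep517
    os.system(cmd)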
504
invokeai/frontend/config/model_install.py
Normal file
@@ -0,0 +1,504 @@
#!/usr/bin/env python
# Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein)
# Before running stable-diffusion on an internet-isolated machine,
# run this script from one with internet connectivity. The
# two machines must share a common .cache directory.

"""
This is the npyscreen frontend to the model installation application.
The work is actually done in backend code in model_install_backend.py.
"""

import argparse
import os
import sys
from argparse import Namespace
from pathlib import Path
from typing import List

import npyscreen
import torch
from npyscreen import widget
from omegaconf import OmegaConf
from shutil import get_terminal_size

from ...backend.util import choose_precision, choose_torch_device
from ...backend.globals import Globals, global_config_dir
from ...backend.config.model_install_backend import (Dataset_path, default_config_file,
                                                     default_dataset, get_root,
                                                     install_requested_models,
                                                     recommended_datasets,
                                                     )
from .widgets import (MultiSelectColumns, TextBox,
                      OffsetButtonPress, CenteredTitleText,
                      set_min_terminal_size,
                      )

# minimum size for the UI
MIN_COLS = 120
MIN_LINES = 45


class addModelsForm(npyscreen.FormMultiPage):
    # for responsive resizing - disabled
    # FIX_MINIMUM_SIZE_WHEN_CREATED = False

    def __init__(self, parentApp, name, multipage=False, *args, **keywords):
        self.multipage = multipage
        self.initial_models = OmegaConf.load(Dataset_path)
        try:
            self.existing_models = OmegaConf.load(default_config_file())
        except Exception:
            self.existing_models = dict()
        self.starter_model_list = [
            x for x in list(self.initial_models.keys()) if x not in self.existing_models
        ]
        self.installed_models = dict()
        super().__init__(parentApp=parentApp, name=name, *args, **keywords)

    def create(self):
        window_width, window_height = get_terminal_size()
        starter_model_labels = self._get_starter_model_labels()
        recommended_models = [
            x
            for x in self.starter_model_list
            if self.initial_models[x].get("recommended", False)
        ]
        self.installed_models = sorted(
            [x for x in list(self.initial_models.keys()) if x in self.existing_models]
        )
        self.nextrely -= 1
        self.add_widget_intelligent(
            npyscreen.FixedText,
            value="Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields,",
            editable=False,
            color='CAUTION',
        )
        self.add_widget_intelligent(
            npyscreen.FixedText,
            value="Use cursor arrows to make a selection, and space to toggle checkboxes.",
            editable=False,
            color='CAUTION',
        )
        self.nextrely += 1
        if len(self.installed_models) > 0:
            self.add_widget_intelligent(
                CenteredTitleText,
                name="== INSTALLED STARTER MODELS ==",
                editable=False,
                color="CONTROL",
            )
            self.nextrely -= 1
            self.add_widget_intelligent(
                CenteredTitleText,
                name="Currently installed starter models. Uncheck to delete:",
                editable=False,
                labelColor="CAUTION",
            )
            self.nextrely -= 1
            columns = self._get_columns()
            self.previously_installed_models = self.add_widget_intelligent(
                MultiSelectColumns,
                columns=columns,
                values=self.installed_models,
                value=[x for x in range(0, len(self.installed_models))],
                max_height=1 + len(self.installed_models) // columns,
                relx=4,
                slow_scroll=True,
                scroll_exit=True,
            )
            self.purge_deleted = self.add_widget_intelligent(
                npyscreen.Checkbox,
                name="Purge deleted models from disk",
                value=False,
                scroll_exit=True,
                relx=4,
            )
        self.nextrely += 1
        if len(self.starter_model_list) > 0:
            self.add_widget_intelligent(
                CenteredTitleText,
                name="== STARTER MODELS (recommended ones selected) ==",
                editable=False,
                color="CONTROL",
            )
            self.nextrely -= 1
            self.add_widget_intelligent(
                CenteredTitleText,
                name="Select from a starter set of Stable Diffusion models from HuggingFace.",
                editable=False,
                labelColor="CAUTION",
            )
            self.nextrely -= 1
            # if user has already installed some initial models, then don't patronize them
            # by showing more recommendations
            show_recommended = not self.existing_models
            self.models_selected = self.add_widget_intelligent(
                npyscreen.MultiSelect,
                name="Install Starter Models",
                values=starter_model_labels,
                value=[
                    self.starter_model_list.index(x)
                    for x in self.starter_model_list
                    if show_recommended and x in recommended_models
                ],
                max_height=len(starter_model_labels) + 1,
                relx=4,
                scroll_exit=True,
            )
        self.add_widget_intelligent(
            CenteredTitleText,
            name='== IMPORT LOCAL AND REMOTE MODELS ==',
            editable=False,
            color="CONTROL",
        )
        self.nextrely -= 1

        for line in [
            "In the box below, enter URLs, file paths, or HuggingFace repository IDs.",
            "Separate model names by lines or whitespace (Use shift-control-V to paste):",
        ]:
            self.add_widget_intelligent(
                CenteredTitleText,
                name=line,
                editable=False,
                labelColor="CONTROL",
                relx=4,
            )
            self.nextrely -= 1
        self.import_model_paths = self.add_widget_intelligent(
            TextBox,
            max_height=7,
            scroll_exit=True,
            editable=True,
            relx=4,
        )
        self.nextrely += 1
        self.show_directory_fields = self.add_widget_intelligent(
            npyscreen.FormControlCheckbox,
            name="Select a directory for models to import",
            value=False,
        )
        self.autoload_directory = self.add_widget_intelligent(
            npyscreen.TitleFilename,
            name="Directory (<tab> autocompletes):",
            select_dir=True,
            must_exist=True,
            use_two_lines=False,
            labelColor="DANGER",
            begin_entry_at=34,
            scroll_exit=True,
        )
        self.autoscan_on_startup = self.add_widget_intelligent(
            npyscreen.Checkbox,
            name="Scan this directory each time InvokeAI starts for new models to import",
            value=False,
            relx=4,
            scroll_exit=True,
        )
        self.nextrely += 1
        self.convert_models = self.add_widget_intelligent(
            npyscreen.TitleSelectOne,
            name="== CONVERT IMPORTED MODELS INTO DIFFUSERS ==",
            values=["Keep original format", "Convert to diffusers"],
            value=0,
            begin_entry_at=4,
            max_height=4,
            hidden=True,  # will appear when imported models box is edited
            scroll_exit=True,
        )
        self.cancel = self.add_widget_intelligent(
            npyscreen.ButtonPress,
            name="CANCEL",
            rely=-3,
            when_pressed_function=self.on_cancel,
        )
        done_label = "DONE"
        back_label = "BACK"
        button_length = len(done_label)
        button_offset = 0
        if self.multipage:
            button_length += len(back_label) + 1
            button_offset += len(back_label) + 1
            self.back_button = self.add_widget_intelligent(
                OffsetButtonPress,
                name=back_label,
                relx=(window_width - button_length) // 2,
                offset=-3,
                rely=-3,
                when_pressed_function=self.on_back,
            )
        self.ok_button = self.add_widget_intelligent(
            OffsetButtonPress,
            name=done_label,
            offset=+3,
            relx=button_offset + 1 + (window_width - button_length) // 2,
            rely=-3,
            when_pressed_function=self.on_ok,
        )

        for i in [self.autoload_directory, self.autoscan_on_startup]:
            self.show_directory_fields.addVisibleWhenSelected(i)

        self.show_directory_fields.when_value_edited = self._clear_scan_directory
        self.import_model_paths.when_value_edited = self._show_hide_convert
        self.autoload_directory.when_value_edited = self._show_hide_convert

    def resize(self):
        super().resize()
        if hasattr(self, 'models_selected'):
            self.models_selected.values = self._get_starter_model_labels()

    def _clear_scan_directory(self):
        if not self.show_directory_fields.value:
            self.autoload_directory.value = ""

    def _show_hide_convert(self):
        model_paths = self.import_model_paths.value or ""
        autoload_directory = self.autoload_directory.value or ""
        self.convert_models.hidden = (
            len(model_paths) == 0 and len(autoload_directory) == 0
        )

    def _get_starter_model_labels(self) -> List[str]:
        window_width, window_height = get_terminal_size()
        label_width = 25
        checkbox_width = 4
        spacing_width = 2
        description_width = window_width - label_width - checkbox_width - spacing_width
        im = self.initial_models
        names = self.starter_model_list
        descriptions = [
            im[x].description[0 : description_width - 3] + "..."
            if len(im[x].description) > description_width
            else im[x].description
            for x in names
        ]
        return [
            f"%-{label_width}s %s" % (names[x], descriptions[x])
            for x in range(0, len(names))
        ]

    def _get_columns(self) -> int:
        window_width, window_height = get_terminal_size()
        cols = (
            4
            if window_width > 240
            else 3
            if window_width > 160
            else 2
            if window_width > 80
            else 1
        )
        return min(cols, len(self.installed_models))

    def on_ok(self):
        self.parentApp.setNextForm(None)
        self.editing = False
        self.parentApp.user_cancelled = False
        self.marshall_arguments()

    def on_back(self):
        self.parentApp.switchFormPrevious()
        self.editing = False

    def on_cancel(self):
        if npyscreen.notify_yes_no(
            "Are you sure you want to cancel?\nYou may re-run this script later using the invoke.sh or invoke.bat command.\n"
        ):
            self.parentApp.setNextForm(None)
            self.parentApp.user_cancelled = True
            self.editing = False

    def marshall_arguments(self):
        """
        Assemble arguments and store as attributes of the application:
        .starter_models: dict of model names to install from INITIAL_CONFIGURE.yaml
                         True  => Install
                         False => Remove
        .scan_directory: Path to a directory of models to scan and import
        .autoscan_on_startup:  True if invokeai should scan and import at startup time
        .import_model_paths:   list of URLs, repo_ids and file paths to import
        .convert_to_diffusers: if True, convert legacy checkpoints into diffusers
        """
        # we're using a global here rather than storing the result in the parentapp
        # due to some bug in npyscreen that is causing attributes to be lost
        selections = self.parentApp.user_selections

        # starter models to install/remove
        if hasattr(self, 'models_selected'):
            starter_models = dict(
                map(
                    lambda x: (self.starter_model_list[x], True), self.models_selected.value
                )
            )
        else:
            starter_models = dict()
        selections.purge_deleted_models = False
        if hasattr(self, "previously_installed_models"):
            unchecked = [
                self.previously_installed_models.values[x]
                for x in range(0, len(self.previously_installed_models.values))
                if x not in self.previously_installed_models.value
            ]
            starter_models.update(map(lambda x: (x, False), unchecked))
            selections.purge_deleted_models = self.purge_deleted.value
        selections.starter_models = starter_models

        # load directory and whether to scan on startup
        if self.show_directory_fields.value:
            selections.scan_directory = self.autoload_directory.value
            selections.autoscan_on_startup = self.autoscan_on_startup.value
        else:
            selections.scan_directory = None
            selections.autoscan_on_startup = False

        # URLs and the like
        selections.import_model_paths = self.import_model_paths.value.split()
        selections.convert_to_diffusers = self.convert_models.value[0] == 1


class AddModelApplication(npyscreen.NPSAppManaged):
    def __init__(self):
        super().__init__()
        self.user_cancelled = False
        self.user_selections = Namespace(
            starter_models=None,
            purge_deleted_models=False,
            scan_directory=None,
            autoscan_on_startup=None,
            import_model_paths=None,
            convert_to_diffusers=None,
        )

    def onStart(self):
        npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
        self.main_form = self.addForm(
            "MAIN", addModelsForm, name="Install Stable Diffusion Models"
        )


# --------------------------------------------------------
def process_and_execute(opt: Namespace, selections: Namespace):
    models_to_remove = [
        x for x in selections.starter_models if not selections.starter_models[x]
    ]
    models_to_install = [
        x for x in selections.starter_models if selections.starter_models[x]
    ]
    directory_to_scan = selections.scan_directory
    scan_at_startup = selections.autoscan_on_startup
    potential_models_to_install = selections.import_model_paths
    convert_to_diffusers = selections.convert_to_diffusers

    install_requested_models(
        install_initial_models=models_to_install,
        remove_models=models_to_remove,
        scan_directory=Path(directory_to_scan) if directory_to_scan else None,
        external_models=potential_models_to_install,
        scan_at_startup=scan_at_startup,
        convert_to_diffusers=convert_to_diffusers,
        precision="float32"
        if opt.full_precision
        else choose_precision(torch.device(choose_torch_device())),
        purge_deleted=selections.purge_deleted_models,
        config_file_path=Path(opt.config_file) if opt.config_file else None,
    )


# --------------------------------------------------------
def select_and_download_models(opt: Namespace):
    precision = (
        "float32"
        if opt.full_precision
        else choose_precision(torch.device(choose_torch_device()))
    )
    if opt.default_only:
        install_requested_models(
            install_initial_models=default_dataset(),
            precision=precision,
        )
    elif opt.yes_to_all:
        install_requested_models(
            install_initial_models=recommended_datasets(),
            precision=precision,
        )
    else:
        set_min_terminal_size(MIN_COLS, MIN_LINES)
        installApp = AddModelApplication()
        installApp.run()

        if not installApp.user_cancelled:
            process_and_execute(opt, installApp.user_selections)


# -------------------------------------
def main():
    parser = argparse.ArgumentParser(description="InvokeAI model downloader")
    parser.add_argument(
        "--full-precision",
        dest="full_precision",
        action=argparse.BooleanOptionalAction,
        type=bool,
        default=False,
        help="use 32-bit weights instead of faster 16-bit weights",
    )
    parser.add_argument(
        "--yes",
        "-y",
        dest="yes_to_all",
        action="store_true",
        help='answer "yes" to all prompts',
    )
    parser.add_argument(
        "--default_only",
        action="store_true",
        help="only install the default model",
    )
    parser.add_argument(
        "--config_file",
        "-c",
        dest="config_file",
        type=str,
        default=None,
        help="path to configuration file to create",
    )
    parser.add_argument(
        "--root_dir",
        dest="root",
        type=str,
        default=None,
        help="path to root of install directory",
    )
    opt = parser.parse_args()

    # setting a global here
    Globals.root = os.path.expanduser(get_root(opt.root) or "")

    if not global_config_dir().exists():
        print(
            ">> Your InvokeAI root directory is not set up. Calling invokeai-configure."
        )
        from ...backend.config.invokeai_configure import main as run_configure

        run_configure()
        sys.exit(0)

    try:
        select_and_download_models(opt)
    except AssertionError as e:
        print(str(e))
        sys.exit(-1)
    except KeyboardInterrupt:
        print("\nGoodbye! Come back soon.")
    except widget.NotEnoughSpaceForWidget as e:
        if str(e).startswith("Height of 1 allocated"):
            print(
                "** Insufficient vertical space for the interface. Please make your window taller and try again"
            )
        elif str(e).startswith("addwstr"):
            print(
                "** Insufficient horizontal space for the interface. Please make your window wider and try again."
            )


# -------------------------------------
if __name__ == "__main__":
    main()
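For orientation, the two non-interactive paths above (--default_only and --yes) bypass the npyscreen form and call install_requested_models() directly. Here is a minimal sketch of driving the same entry point from Python; the Namespace fields mirror the argparse options, and the values shown are illustrative:

    from argparse import Namespace
    from invokeai.frontend.config.model_install import select_and_download_models

    # roughly what `invokeai-model-install --yes` parses to
    opt = Namespace(
        full_precision=False,  # keep the faster 16-bit weights
        yes_to_all=True,       # install the recommended datasets without prompting
        default_only=False,
        config_file=None,
        root=None,
    )
    select_and_download_models(opt)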
164
invokeai/frontend/config/widgets.py
Normal file
@@ -0,0 +1,164 @@
'''
Widget class definitions used by model_install.py, merge_diffusers.py and textual_inversion.py
'''
import math
import platform
import npyscreen
import os
import sys
import curses
import struct

from shutil import get_terminal_size


# -------------------------------------
def set_terminal_size(columns: int, lines: int):
    OS = platform.uname().system
    if OS == "Windows":
        os.system(f'mode con: cols={columns} lines={lines}')
    elif OS in ['Darwin', 'Linux']:
        import termios
        import fcntl

        winsize = struct.pack("HHHH", lines, columns, 0, 0)
        fcntl.ioctl(sys.stdout.fileno(), termios.TIOCSWINSZ, winsize)
        sys.stdout.write("\x1b[8;{rows};{cols}t".format(rows=lines, cols=columns))
        sys.stdout.flush()


def set_min_terminal_size(min_cols: int, min_lines: int):
    # make sure there's enough room for the ui
    term_cols, term_lines = get_terminal_size()
    cols = max(term_cols, min_cols)
    lines = max(term_lines, min_lines)
    set_terminal_size(cols, lines)


class IntSlider(npyscreen.Slider):
    def translate_value(self):
        stri = "%2d / %2d" % (self.value, self.out_of)
        l = (len(str(self.out_of))) * 2 + 4
        stri = stri.rjust(l)
        return stri


# -------------------------------------
class CenteredTitleText(npyscreen.TitleText):
    def __init__(self, *args, **keywords):
        super().__init__(*args, **keywords)
        self.resize()

    def resize(self):
        super().resize()
        maxy, maxx = self.parent.curses_pad.getmaxyx()
        label = self.name
        self.relx = (maxx - len(label)) // 2


# -------------------------------------
class CenteredButtonPress(npyscreen.ButtonPress):
    def resize(self):
        super().resize()
        maxy, maxx = self.parent.curses_pad.getmaxyx()
        label = self.name
        self.relx = (maxx - len(label)) // 2


# -------------------------------------
class OffsetButtonPress(npyscreen.ButtonPress):
    def __init__(self, screen, offset=0, *args, **keywords):
        super().__init__(screen, *args, **keywords)
        self.offset = offset

    def resize(self):
        maxy, maxx = self.parent.curses_pad.getmaxyx()
        width = len(self.name)
        self.relx = self.offset + (maxx - width) // 2


class IntTitleSlider(npyscreen.TitleText):
    _entry_type = IntSlider


class FloatSlider(npyscreen.Slider):
    # this is supposed to adjust display precision, but doesn't
    def translate_value(self):
        stri = "%3.2f / %3.2f" % (self.value, self.out_of)
        l = (len(str(self.out_of))) * 2 + 4
        stri = stri.rjust(l)
        return stri


class FloatTitleSlider(npyscreen.TitleText):
    _entry_type = FloatSlider


class MultiSelectColumns(npyscreen.MultiSelect):
    def __init__(self, screen, columns: int = 1, values: list = [], **keywords):
        self.columns = columns
        self.value_cnt = len(values)
        self.rows = math.ceil(self.value_cnt / self.columns)
        super().__init__(screen, values=values, **keywords)

    def make_contained_widgets(self):
        self._my_widgets = []
        column_width = self.width // self.columns
        for h in range(self.value_cnt):
            self._my_widgets.append(
                self._contained_widgets(self.parent,
                                        rely=self.rely + (h % self.rows) * self._contained_widget_height,
                                        relx=self.relx + (h // self.rows) * column_width,
                                        max_width=column_width,
                                        max_height=self.__class__._contained_widget_height,
                                        )
            )

    def set_up_handlers(self):
        super().set_up_handlers()
        self.handlers.update({
            curses.KEY_UP: self.h_cursor_line_left,
            curses.KEY_DOWN: self.h_cursor_line_right,
        }
        )

    def h_cursor_line_down(self, ch):
        self.cursor_line += self.rows
        if self.cursor_line >= len(self.values):
            if self.scroll_exit:
                self.cursor_line = len(self.values) - self.rows
                self.h_exit_down(ch)
                return True
            else:
                self.cursor_line -= self.rows
                return True

    def h_cursor_line_up(self, ch):
        self.cursor_line -= self.rows
        if self.cursor_line < 0:
            if self.scroll_exit:
                self.cursor_line = 0
                self.h_exit_up(ch)
            else:
                self.cursor_line = 0

    def h_cursor_line_left(self, ch):
        super().h_cursor_line_up(ch)

    def h_cursor_line_right(self, ch):
        super().h_cursor_line_down(ch)


class TextBox(npyscreen.MultiLineEdit):
    def update(self, clear=True):
        if clear:
            self.clear()

        HEIGHT = self.height
        WIDTH = self.width
        # draw box.
        self.parent.curses_pad.hline(self.rely, self.relx, curses.ACS_HLINE, WIDTH)
        self.parent.curses_pad.hline(self.rely + HEIGHT, self.relx, curses.ACS_HLINE, WIDTH)
        self.parent.curses_pad.vline(self.rely, self.relx, curses.ACS_VLINE, self.height)
        self.parent.curses_pad.vline(self.rely, self.relx + WIDTH, curses.ACS_VLINE, HEIGHT)

        # draw corners
        self.parent.curses_pad.addch(self.rely, self.relx, curses.ACS_ULCORNER)
        self.parent.curses_pad.addch(self.rely, self.relx + WIDTH, curses.ACS_URCORNER)
        self.parent.curses_pad.addch(self.rely + HEIGHT, self.relx, curses.ACS_LLCORNER)
        self.parent.curses_pad.addch(self.rely + HEIGHT, self.relx + WIDTH, curses.ACS_LRCORNER)

        # fool our superclass into thinking drawing area is smaller - this is really hacky but it seems to work
        (relx, rely, height, width) = (self.relx, self.rely, self.height, self.width)
        self.relx += 1
        self.rely += 1
        self.height -= 1
        self.width -= 1
        super().update(clear=False)
        (self.relx, self.rely, self.height, self.width) = (relx, rely, height, width)
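To show how these pieces compose, here is a minimal, self-contained sketch of a form using MultiSelectColumns; the option names are made up for illustration, and the widget lays its values out column-major (here two rows by three columns):

    import npyscreen
    from invokeai.frontend.config.widgets import MultiSelectColumns, set_min_terminal_size

    class DemoForm(npyscreen.FormBaseNew):
        def create(self):
            self.picker = self.add(
                MultiSelectColumns,
                columns=3,
                values=[f'model-{i}' for i in range(6)],  # illustrative entries
                value=[0],      # preselect the first entry
                max_height=3,
                scroll_exit=True,
            )

    class DemoApp(npyscreen.NPSAppManaged):
        def onStart(self):
            self.addForm('MAIN', DemoForm, name='MultiSelectColumns demo')

    if __name__ == '__main__':
        set_min_terminal_size(80, 25)  # grow the terminal first if it is too small
        DemoApp().run()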
4
invokeai/frontend/merge/__init__.py
Normal file
@@ -0,0 +1,4 @@
'''
Initialization file for invokeai.frontend.merge
'''
from .merge_diffusers import main as invokeai_merge_diffusers
467
invokeai/frontend/merge/merge_diffusers.py
Normal file
@ -0,0 +1,467 @@
|
||||
"""
|
||||
ldm.invoke.merge_diffusers exports a single function call merge_diffusion_models()
|
||||
used to merge 2-3 models together and create a new InvokeAI-registered diffusion model.
|
||||
|
||||
Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
|
||||
"""
|
||||
import argparse
|
||||
import curses
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import warnings
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
from typing import List, Union
|
||||
|
||||
import npyscreen
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers import logging as dlogging
|
||||
from npyscreen import widget
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from ...frontend.config.widgets import FloatTitleSlider
|
||||
from ...backend.globals import (Globals, global_cache_dir, global_config_file,
|
||||
global_models_dir, global_set_root)
|
||||
from ...backend.model_management import ModelManager
|
||||
|
||||
DEST_MERGED_MODEL_DIR = "merged_models"
|
||||
|
||||
def merge_diffusion_models(
|
||||
model_ids_or_paths: List[Union[str, Path]],
|
||||
alpha: float = 0.5,
|
||||
interp: str = None,
|
||||
force: bool = False,
|
||||
**kwargs,
|
||||
) -> DiffusionPipeline:
|
||||
"""
|
||||
model_ids_or_paths - up to three models, designated by their local paths or HuggingFace repo_ids
|
||||
alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
|
||||
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
|
||||
interp - The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None.
|
||||
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported.
|
||||
force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
|
||||
|
||||
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
||||
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
||||
"""
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
verbosity = dlogging.get_verbosity()
|
||||
dlogging.set_verbosity_error()
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
model_ids_or_paths[0],
|
||||
cache_dir=kwargs.get("cache_dir", global_cache_dir()),
|
||||
custom_pipeline="checkpoint_merger",
|
||||
)
|
||||
merged_pipe = pipe.merge(
|
||||
pretrained_model_name_or_path_list=model_ids_or_paths,
|
||||
alpha=alpha,
|
||||
interp=interp,
|
||||
force=force,
|
||||
**kwargs,
|
||||
)
|
||||
dlogging.set_verbosity(verbosity)
|
||||
return merged_pipe
|
||||
|
||||
|
||||
def merge_diffusion_models_and_commit(
|
||||
models: List["str"],
|
||||
merged_model_name: str,
|
||||
alpha: float = 0.5,
|
||||
interp: str = None,
|
||||
force: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
models - up to three models, designated by their InvokeAI models.yaml model name
|
||||
merged_model_name = name for new model
|
||||
alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
|
||||
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
|
||||
interp - The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None.
|
||||
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C).
|
||||
force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
|
||||
|
||||
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
||||
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
||||
"""
|
||||
config_file = global_config_file()
|
||||
model_manager = ModelManager(OmegaConf.load(config_file))
|
||||
for mod in models:
|
||||
assert mod in model_manager.model_names(), f'** Unknown model "{mod}"'
|
||||
assert (
|
||||
model_manager.model_info(mod).get("format", None) == "diffusers"
|
||||
), f"** {mod} is not a diffusers model. It must be optimized before merging."
|
||||
model_ids_or_paths = [model_manager.model_name_or_path(x) for x in models]
|
||||
|
||||
merged_pipe = merge_diffusion_models(
|
||||
model_ids_or_paths, alpha, interp, force, **kwargs
|
||||
)
|
||||
dump_path = global_models_dir() / DEST_MERGED_MODEL_DIR
|
||||
|
||||
os.makedirs(dump_path, exist_ok=True)
|
||||
dump_path = dump_path / merged_model_name
|
||||
merged_pipe.save_pretrained(dump_path, safe_serialization=1)
|
||||
import_args = dict(
|
||||
model_name=merged_model_name, description=f'Merge of models {", ".join(models)}'
|
||||
)
|
||||
if vae := model_manager.config[models[0]].get("vae", None):
|
||||
print(f">> Using configured VAE assigned to {models[0]}")
|
||||
import_args.update(vae=vae)
|
||||
model_manager.import_diffuser_model(dump_path, **import_args)
|
||||
model_manager.commit(config_file)
|
||||
|
||||
|
||||
def _parse_args() -> Namespace:
|
||||
parser = argparse.ArgumentParser(description="InvokeAI model merging")
|
||||
parser.add_argument(
|
||||
"--root_dir",
|
||||
type=Path,
|
||||
default=Globals.root,
|
||||
help="Path to the invokeai runtime directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--front_end",
|
||||
"--gui",
|
||||
dest="front_end",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Activate the text-based graphical front end for collecting parameters. Aside from --root_dir, other parameters will be ignored.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--models",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="Two to three model names to be merged",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--merged_model_name",
|
||||
"--destination",
|
||||
dest="merged_model_name",
|
||||
type=str,
|
||||
help="Name of the output model. If not specified, will be the concatenation of the input model names.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--alpha",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="The interpolation parameter, ranging from 0 to 1. It affects the ratio in which the checkpoints are merged. Higher values give more weight to the 2d and 3d models",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--interpolation",
|
||||
dest="interp",
|
||||
type=str,
|
||||
choices=["weighted_sum", "sigmoid", "inv_sigmoid", "add_difference"],
|
||||
default="weighted_sum",
|
||||
help='Interpolation method to use. If three models are present, only "add_difference" will work.',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--force",
|
||||
action="store_true",
|
||||
help="Try to merge models even if they are incompatible with each other",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--clobber",
|
||||
"--overwrite",
|
||||
dest="clobber",
|
||||
action="store_true",
|
||||
help="Overwrite the merged model if --merged_model_name already exists",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
# ------------------------- GUI HERE -------------------------
|
||||
class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
interpolations = ["weighted_sum", "sigmoid", "inv_sigmoid"]
|
||||
|
||||
def __init__(self, parentApp, name):
|
||||
self.parentApp = parentApp
|
||||
self.ALLOW_RESIZE = True
|
||||
self.FIX_MINIMUM_SIZE_WHEN_CREATED = False
|
||||
super().__init__(parentApp, name)
|
||||
|
||||
@property
|
||||
def model_manager(self):
|
||||
return self.parentApp.model_manager
|
||||
|
||||
def afterEditing(self):
|
||||
self.parentApp.setNextForm(None)
|
||||
|
||||
def create(self):
|
||||
window_height, window_width = curses.initscr().getmaxyx()
|
||||
|
||||
self.model_names = self.get_model_names()
|
||||
max_width = max([len(x) for x in self.model_names])
|
||||
max_width += 6
|
||||
horizontal_layout = max_width * 3 < window_width
|
||||
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
color="CONTROL",
|
||||
value="Select two models to merge and optionally a third.",
|
||||
editable=False,
|
||||
)
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
color="CONTROL",
|
||||
value="Use up and down arrows to move, <space> to select an item, <tab> and <shift-tab> to move from one field to the next.",
|
||||
editable=False,
|
||||
)
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value="MODEL 1",
|
||||
color="GOOD",
|
||||
editable=False,
|
||||
rely=4 if horizontal_layout else None,
|
||||
)
|
||||
self.model1 = self.add_widget_intelligent(
|
||||
npyscreen.SelectOne,
|
||||
values=self.model_names,
|
||||
value=0,
|
||||
max_height=len(self.model_names),
|
||||
max_width=max_width,
|
||||
scroll_exit=True,
|
||||
rely=5,
|
||||
)
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value="MODEL 2",
|
||||
color="GOOD",
|
||||
editable=False,
|
||||
relx=max_width + 3 if horizontal_layout else None,
|
||||
rely=4 if horizontal_layout else None,
|
||||
)
|
||||
self.model2 = self.add_widget_intelligent(
|
||||
npyscreen.SelectOne,
|
||||
name="(2)",
|
||||
values=self.model_names,
|
||||
value=1,
|
||||
max_height=len(self.model_names),
|
||||
max_width=max_width,
|
||||
relx=max_width + 3 if horizontal_layout else None,
|
||||
rely=5 if horizontal_layout else None,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value="MODEL 3",
|
||||
color="GOOD",
|
||||
editable=False,
|
||||
relx=max_width * 2 + 3 if horizontal_layout else None,
|
||||
rely=4 if horizontal_layout else None,
|
||||
)
|
||||
models_plus_none = self.model_names.copy()
|
||||
models_plus_none.insert(0, "None")
|
||||
self.model3 = self.add_widget_intelligent(
|
||||
npyscreen.SelectOne,
|
||||
name="(3)",
|
||||
values=models_plus_none,
|
||||
value=0,
|
||||
max_height=len(self.model_names) + 1,
|
||||
max_width=max_width,
|
||||
scroll_exit=True,
|
||||
relx=max_width * 2 + 3 if horizontal_layout else None,
|
||||
rely=5 if horizontal_layout else None,
|
||||
)
|
||||
for m in [self.model1, self.model2, self.model3]:
|
||||
m.when_value_edited = self.models_changed
|
||||
self.merged_model_name = self.add_widget_intelligent(
|
||||
npyscreen.TitleText,
|
||||
name="Name for merged model:",
|
||||
labelColor="CONTROL",
|
||||
value="",
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.force = self.add_widget_intelligent(
|
||||
npyscreen.Checkbox,
|
||||
name="Force merge of incompatible models",
|
||||
labelColor="CONTROL",
|
||||
value=False,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.merge_method = self.add_widget_intelligent(
|
||||
npyscreen.TitleSelectOne,
|
||||
name="Merge Method:",
|
||||
values=self.interpolations,
|
||||
value=0,
|
||||
labelColor="CONTROL",
|
||||
max_height=len(self.interpolations) + 1,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.alpha = self.add_widget_intelligent(
|
||||
FloatTitleSlider,
|
||||
name="Weight (alpha) to assign to second and third models:",
|
||||
out_of=1.0,
|
||||
step=0.01,
|
||||
lowest=0,
|
||||
value=0.5,
|
||||
labelColor="CONTROL",
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.model1.editing = True
|
||||
|
||||
def models_changed(self):
|
||||
models = self.model1.values
|
||||
selected_model1 = self.model1.value[0]
|
||||
selected_model2 = self.model2.value[0]
|
||||
selected_model3 = self.model3.value[0]
|
||||
merged_model_name = f"{models[selected_model1]}+{models[selected_model2]}"
|
||||
self.merged_model_name.value = merged_model_name
|
||||
|
||||
if selected_model3 > 0:
|
||||
self.merge_method.values = ['add_difference ( A+(B-C) )']
|
||||
self.merged_model_name.value += f"+{models[selected_model3 -1]}" # In model3 there is one more element in the list (None). So we have to subtract one.
|
||||
else:
|
||||
self.merge_method.values = self.interpolations
|
||||
self.merge_method.value = 0
|
||||
|
||||
def on_ok(self):
|
||||
if self.validate_field_values() and self.check_for_overwrite():
|
||||
self.parentApp.setNextForm(None)
|
||||
self.editing = False
|
||||
self.parentApp.merge_arguments = self.marshall_arguments()
|
||||
npyscreen.notify("Starting the merge...")
|
||||
else:
|
||||
self.editing = True
|
||||
|
||||
def on_cancel(self):
|
||||
sys.exit(0)
|
||||
|
||||
def marshall_arguments(self) -> dict:
|
||||
model_names = self.model_names
|
||||
models = [
|
||||
model_names[self.model1.value[0]],
|
||||
model_names[self.model2.value[0]],
|
||||
]
|
||||
if self.model3.value[0] > 0:
|
||||
models.append(model_names[self.model3.value[0] - 1])
|
||||
interp='add_difference'
|
||||
else:
|
||||
interp=self.interpolations[self.merge_method.value[0]]
|
||||
|
||||
args = dict(
|
||||
models=models,
|
||||
alpha=self.alpha.value,
|
||||
interp=interp,
|
||||
force=self.force.value,
|
||||
merged_model_name=self.merged_model_name.value,
|
||||
)
|
||||
return args
|
||||
|
||||
def check_for_overwrite(self) -> bool:
|
||||
model_out = self.merged_model_name.value
|
||||
if model_out not in self.model_names:
|
||||
return True
|
||||
else:
|
||||
return npyscreen.notify_yes_no(
|
||||
f"The chosen merged model destination, {model_out}, is already in use. Overwrite?"
|
||||
)
|
||||
|
||||
def validate_field_values(self) -> bool:
|
||||
bad_fields = []
|
||||
model_names = self.model_names
|
||||
selected_models = set(
|
||||
(model_names[self.model1.value[0]], model_names[self.model2.value[0]])
|
||||
)
|
||||
if self.model3.value[0] > 0:
|
||||
selected_models.add(model_names[self.model3.value[0] - 1])
|
||||
if len(selected_models) < 2:
|
||||
bad_fields.append(
|
||||
f"Please select two or three DIFFERENT models to compare. You selected {selected_models}"
|
||||
)
|
||||
if len(bad_fields) > 0:
|
||||
message = "The following problems were detected and must be corrected:"
|
||||
for problem in bad_fields:
|
||||
message += f"\n* {problem}"
|
||||
npyscreen.notify_confirm(message)
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
    def get_model_names(self) -> List[str]:
        model_names = [
            name
            for name in self.model_manager.model_names()
            if self.model_manager.model_info(name).get("format") == "diffusers"
        ]
        return sorted(model_names)


class Mergeapp(npyscreen.NPSAppManaged):
    def __init__(self):
        super().__init__()
        conf = OmegaConf.load(global_config_file())
        self.model_manager = ModelManager(
            conf, "cpu", "float16"
        )  # precision doesn't really matter here

    def onStart(self):
        npyscreen.setTheme(npyscreen.Themes.ElegantTheme)
        self.main = self.addForm("MAIN", mergeModelsForm, name="Merge Models Settings")


def run_gui(args: Namespace):
    mergeapp = Mergeapp()
    mergeapp.run()

    args = mergeapp.merge_arguments
    merge_diffusion_models_and_commit(**args)
    print(f'>> Models merged into new model: "{args["merged_model_name"]}".')


def run_cli(args: Namespace):
    assert 0 <= args.alpha <= 1.0, "alpha must be between 0 and 1"
    assert (
        args.models and 2 <= len(args.models) <= 3
    ), "Please provide the --models argument to list 2 to 3 models to merge. Use --help for full usage."

    if not args.merged_model_name:
        args.merged_model_name = "+".join(args.models)
        print(
            f'>> No --merged_model_name provided. Defaulting to "{args.merged_model_name}"'
        )

    model_manager = ModelManager(OmegaConf.load(global_config_file()))
    assert (
        args.clobber or args.merged_model_name not in model_manager.model_names()
    ), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'

    merge_diffusion_models_and_commit(**vars(args))
    print(f'>> Models merged into new model: "{args.merged_model_name}".')


def main():
    args = _parse_args()
    global_set_root(args.root_dir)

    # It is not clear that the merge pipeline honors cache_dir, so also point
    # HF_HOME (the Hugging Face cache root) at the same location to be safe.
    cache_dir = str(global_cache_dir("diffusers"))
    os.environ["HF_HOME"] = cache_dir
    args.cache_dir = cache_dir

    try:
        if args.front_end:
            run_gui(args)
        else:
            run_cli(args)
    except widget.NotEnoughSpaceForWidget as e:
        if str(e).startswith("Height of 1 allocated"):
            print(
                "** You need to have at least two diffusers models defined in models.yaml in order to merge"
            )
        else:
            print(
                "** Not enough room for the user interface. Try making this window larger."
            )
        sys.exit(-1)
    except KeyboardInterrupt:
        sys.exit(-1)
    except Exception:
        print(">> An error occurred:")
        traceback.print_exc()
        sys.exit(-1)


if __name__ == "__main__":
    main()
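For orientation, the CLI path above can be exercised without the GUI by handing
run_cli() a Namespace shaped like the output of _parse_args(). This is a hedged
sketch: the model names are hypothetical, "weighted_sum" is assumed to be one of
the supported interpolation methods, and merge_diffusion_models_and_commit() is
assumed to tolerate the extra Namespace fields it receives via **vars(args).

    from argparse import Namespace

    args = Namespace(
        models=["modelA", "modelB"],   # hypothetical entries from models.yaml
        alpha=0.5,                     # weight given to the second model
        interp="weighted_sum",         # assumed interpolation method name
        force=False,
        clobber=False,                 # run_cli() asserts on this before overwriting
        merged_model_name=None,        # defaults to "modelA+modelB" in run_cli()
    )
    run_cli(args)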
5
invokeai/frontend/training/__init__.py
Normal file
@ -0,0 +1,5 @@
'''
Initialization file for invokeai.frontend.training
'''
from .textual_inversion import main as invokeai_textual_inversion
461
invokeai/frontend/training/textual_inversion.py
Executable file
@ -0,0 +1,461 @@
#!/usr/bin/env python

"""
This is the frontend to "textual_inversion_training.py".

Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
"""


import os
import re
import shutil
import sys
import traceback
from argparse import Namespace
from pathlib import Path
from typing import List, Tuple

import npyscreen
from npyscreen import widget
from omegaconf import OmegaConf

from invokeai.backend.globals import Globals, global_set_root
from ...backend.training import (
    do_textual_inversion_training,
    parse_args,
)

TRAINING_DATA = "text-inversion-training-data"
TRAINING_DIR = "text-inversion-output"
CONF_FILE = "preferences.conf"


class textualInversionForm(npyscreen.FormMultiPageAction):
    resolutions = [512, 768, 1024]
    lr_schedulers = [
        "linear",
        "cosine",
        "cosine_with_restarts",
        "polynomial",
        "constant",
        "constant_with_warmup",
    ]
    precisions = ["no", "fp16", "bf16"]
    learnable_properties = ["object", "style"]

    def __init__(self, parentApp, name, saved_args=None):
        self.saved_args = saved_args or {}
        super().__init__(parentApp, name)

    def afterEditing(self):
        self.parentApp.setNextForm(None)

    def create(self):
        self.model_names, default = self.get_model_names()
        default_initializer_token = "★"
        default_placeholder_token = ""
        saved_args = self.saved_args

        try:
            default = self.model_names.index(saved_args["model"])
        except (KeyError, ValueError):
            pass

        self.add_widget_intelligent(
            npyscreen.FixedText,
            value="Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields, cursor arrows to make a selection, and space to toggle checkboxes.",
            editable=False,
        )

        self.model = self.add_widget_intelligent(
            npyscreen.TitleSelectOne,
            name="Model Name:",
            values=self.model_names,
            value=default,
            max_height=len(self.model_names) + 1,
            scroll_exit=True,
        )
        self.placeholder_token = self.add_widget_intelligent(
            npyscreen.TitleText,
            name="Trigger Term:",
            value="",  # saved_args.get('placeholder_token',''), # to restore previous term
            scroll_exit=True,
        )
        self.placeholder_token.when_value_edited = self.initializer_changed
        self.nextrely -= 1
        self.nextrelx += 30
        self.prompt_token = self.add_widget_intelligent(
            npyscreen.FixedText,
            name="Trigger term for use in prompt",
            value="",
            editable=False,
            scroll_exit=True,
        )
        self.nextrelx -= 30
        self.initializer_token = self.add_widget_intelligent(
            npyscreen.TitleText,
            name="Initializer:",
            value=saved_args.get("initializer_token", default_initializer_token),
            scroll_exit=True,
        )
        self.resume_from_checkpoint = self.add_widget_intelligent(
            npyscreen.Checkbox,
            name="Resume from last saved checkpoint",
            value=False,
            scroll_exit=True,
        )
        self.learnable_property = self.add_widget_intelligent(
            npyscreen.TitleSelectOne,
            name="Learnable property:",
            values=self.learnable_properties,
            value=self.learnable_properties.index(
                saved_args.get("learnable_property", "object")
            ),
            max_height=4,
            scroll_exit=True,
        )
        self.train_data_dir = self.add_widget_intelligent(
            npyscreen.TitleFilename,
            name="Data Training Directory:",
            select_dir=True,
            must_exist=False,
            value=str(
                saved_args.get(
                    "train_data_dir",
                    Path(Globals.root) / TRAINING_DATA / default_placeholder_token,
                )
            ),
            scroll_exit=True,
        )
        self.output_dir = self.add_widget_intelligent(
            npyscreen.TitleFilename,
            name="Output Destination Directory:",
            select_dir=True,
            must_exist=False,
            value=str(
                saved_args.get(
                    "output_dir",
                    Path(Globals.root) / TRAINING_DIR / default_placeholder_token,
                )
            ),
            scroll_exit=True,
        )
        self.resolution = self.add_widget_intelligent(
            npyscreen.TitleSelectOne,
            name="Image resolution (pixels):",
            values=self.resolutions,
            value=self.resolutions.index(saved_args.get("resolution", 512)),
            max_height=4,
            scroll_exit=True,
        )
        self.center_crop = self.add_widget_intelligent(
            npyscreen.Checkbox,
            name="Center crop images before resizing to resolution",
            value=saved_args.get("center_crop", False),
            scroll_exit=True,
        )
        self.mixed_precision = self.add_widget_intelligent(
            npyscreen.TitleSelectOne,
            name="Mixed Precision:",
            values=self.precisions,
            value=self.precisions.index(saved_args.get("mixed_precision", "fp16")),
            max_height=4,
            scroll_exit=True,
        )
        self.num_train_epochs = self.add_widget_intelligent(
            npyscreen.TitleSlider,
            name="Number of training epochs:",
            out_of=1000,
            step=50,
            lowest=1,
            value=saved_args.get("num_train_epochs", 100),
            scroll_exit=True,
        )
        self.max_train_steps = self.add_widget_intelligent(
            npyscreen.TitleSlider,
            name="Max Training Steps:",
            out_of=10000,
            step=500,
            lowest=1,
            value=saved_args.get("max_train_steps", 3000),
            scroll_exit=True,
        )
        self.train_batch_size = self.add_widget_intelligent(
            npyscreen.TitleSlider,
            name="Batch Size (reduce if you run out of memory):",
            out_of=50,
            step=1,
            lowest=1,
            value=saved_args.get("train_batch_size", 8),
            scroll_exit=True,
        )
        self.gradient_accumulation_steps = self.add_widget_intelligent(
            npyscreen.TitleSlider,
            name="Gradient Accumulation Steps (may need to decrease this to resume from a checkpoint):",
            out_of=10,
            step=1,
            lowest=1,
            value=saved_args.get("gradient_accumulation_steps", 4),
            scroll_exit=True,
        )
        self.lr_warmup_steps = self.add_widget_intelligent(
            npyscreen.TitleSlider,
            name="Warmup Steps:",
            out_of=100,
            step=1,
            lowest=0,
            value=saved_args.get("lr_warmup_steps", 0),
            scroll_exit=True,
        )
        self.learning_rate = self.add_widget_intelligent(
            npyscreen.TitleText,
            name="Learning Rate:",
            value=str(saved_args.get("learning_rate", "5.0e-04")),
            scroll_exit=True,
        )
        self.scale_lr = self.add_widget_intelligent(
            npyscreen.Checkbox,
            name="Scale learning rate by number of GPUs, steps and batch size",
            value=saved_args.get("scale_lr", True),
            scroll_exit=True,
        )
        self.enable_xformers_memory_efficient_attention = self.add_widget_intelligent(
            npyscreen.Checkbox,
            name="Use xformers acceleration",
            value=saved_args.get("enable_xformers_memory_efficient_attention", False),
            scroll_exit=True,
        )
        self.lr_scheduler = self.add_widget_intelligent(
            npyscreen.TitleSelectOne,
            name="Learning rate scheduler:",
            values=self.lr_schedulers,
            max_height=7,
            value=self.lr_schedulers.index(saved_args.get("lr_scheduler", "constant")),
            scroll_exit=True,
        )
        self.model.editing = True

    def initializer_changed(self):
        placeholder = self.placeholder_token.value
        self.prompt_token.value = f"(Trigger by using <{placeholder}> in your prompts)"
        self.train_data_dir.value = str(
            Path(Globals.root) / TRAINING_DATA / placeholder
        )
        self.output_dir.value = str(Path(Globals.root) / TRAINING_DIR / placeholder)
        self.resume_from_checkpoint.value = Path(self.output_dir.value).exists()

    def on_ok(self):
        if self.validate_field_values():
            self.parentApp.setNextForm(None)
            self.editing = False
            self.parentApp.ti_arguments = self.marshall_arguments()
            npyscreen.notify(
                "Launching textual inversion training. This will take a while..."
            )
        else:
            self.editing = True

    def on_cancel(self):
        sys.exit(0)

    def validate_field_values(self) -> bool:
        bad_fields = []
        if self.model.value is None:
            bad_fields.append(
                "Model Name must correspond to a known model in models.yaml"
            )
        if not re.match("^[a-zA-Z0-9.-]+$", self.placeholder_token.value):
            bad_fields.append(
                "Trigger term must only contain alphanumeric characters, the dot and hyphen"
            )
        if self.train_data_dir.value is None:
            bad_fields.append("Data Training Directory cannot be empty")
        if self.output_dir.value is None:
            bad_fields.append("The Output Destination Directory cannot be empty")
        if len(bad_fields) > 0:
            message = "The following problems were detected and must be corrected:"
            for problem in bad_fields:
                message += f"\n* {problem}"
            npyscreen.notify_confirm(message)
            return False
        else:
            return True

    def get_model_names(self) -> Tuple[List[str], int]:
        conf = OmegaConf.load(os.path.join(Globals.root, "configs/models.yaml"))
        model_names = [
            name
            for name in sorted(list(conf.keys()))
            if conf[name].get("format", None) == "diffusers"
        ]
        defaults = [
            idx
            for idx in range(len(model_names))
            if "default" in conf[model_names[idx]]
        ]
        default = defaults[0] if len(defaults) > 0 else 0
        return (model_names, default)

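    # Hedged illustration of the models.yaml shape this method expects -- the
    # model names and values here are hypothetical, but "format" and "default"
    # are exactly the fields inspected above:
    #
    #   stable-diffusion-1.5:
    #     format: diffusers
    #     default: true
    #   some-legacy-model:
    #     format: ckpt      # skipped: only diffusers-format models are listed
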
    def marshall_arguments(self) -> dict:
        args = dict()

        # the choices
        args.update(
            model=self.model_names[self.model.value[0]],
            resolution=self.resolutions[self.resolution.value[0]],
            lr_scheduler=self.lr_schedulers[self.lr_scheduler.value[0]],
            mixed_precision=self.precisions[self.mixed_precision.value[0]],
            learnable_property=self.learnable_properties[
                self.learnable_property.value[0]
            ],
        )

        # all the strings and booleans
        for attr in (
            "initializer_token",
            "placeholder_token",
            "train_data_dir",
            "output_dir",
            "scale_lr",
            "center_crop",
            "enable_xformers_memory_efficient_attention",
        ):
            args[attr] = getattr(self, attr).value

        # all the integers
        for attr in (
            "train_batch_size",
            "gradient_accumulation_steps",
            "num_train_epochs",
            "max_train_steps",
            "lr_warmup_steps",
        ):
            args[attr] = int(getattr(self, attr).value)

        # the floats (just one)
        args.update(learning_rate=float(self.learning_rate.value))

        # a special case
        if self.resume_from_checkpoint.value and Path(self.output_dir.value).exists():
            args["resume_from_checkpoint"] = "latest"

        return args


class MyApplication(npyscreen.NPSAppManaged):
    def __init__(self, saved_args=None):
        super().__init__()
        self.ti_arguments = None
        self.saved_args = saved_args

    def onStart(self):
        npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
        self.main = self.addForm(
            "MAIN",
            textualInversionForm,
            name="Textual Inversion Settings",
            saved_args=self.saved_args,
        )


def copy_to_embeddings_folder(args: dict):
    """
    Copy learned_embeds.bin into the embeddings folder, and offer to
    delete the full model and checkpoints.
    """
    source = Path(args["output_dir"], "learned_embeds.bin")
    dest_dir_name = args["placeholder_token"].strip("<>")
    destination = Path(Globals.root, "embeddings", dest_dir_name)
    os.makedirs(destination, exist_ok=True)
    print(f">> Training completed. Copying learned_embeds.bin into {str(destination)}")
    shutil.copy(source, destination)
    if (
        input("Delete training logs and intermediate checkpoints? [y] ") or "y"
    ).startswith(("y", "Y")):
        shutil.rmtree(Path(args["output_dir"]))
    else:
        print(f'>> Keeping {args["output_dir"]}')


def save_args(args: dict):
    """
    Save the current argument values to an omegaconf file
    """
    dest_dir = Path(Globals.root) / TRAINING_DIR
    os.makedirs(dest_dir, exist_ok=True)
    conf_file = dest_dir / CONF_FILE
    conf = OmegaConf.create(args)
    OmegaConf.save(config=conf, f=conf_file)


def previous_args() -> dict:
    """
    Get the argument values used on the previous run.
    """
    conf_file = Path(Globals.root) / TRAINING_DIR / CONF_FILE
    try:
        conf = OmegaConf.load(conf_file)
        conf["placeholder_token"] = conf["placeholder_token"].strip("<>")
    except Exception:
        conf = None

    return conf

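# A small round-trip sketch for the two helpers above (field values are
# illustrative; anything OmegaConf can serialize will survive the trip):
#
#   save_args({"placeholder_token": "<my-style>", "max_train_steps": 3000})
#   restored = previous_args()  # placeholder_token comes back as "my-style"
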
def do_front_end(args: Namespace):
    saved_args = previous_args()
    myapplication = MyApplication(saved_args=saved_args)
    myapplication.run()

    if args := myapplication.ti_arguments:
        os.makedirs(args["output_dir"], exist_ok=True)

        # Automatically add angle brackets around the trigger
        if not re.match("^<.+>$", args["placeholder_token"]):
            args["placeholder_token"] = f"<{args['placeholder_token']}>"

        args["only_save_embeds"] = True
        save_args(args)

        try:
            do_textual_inversion_training(**args)
            copy_to_embeddings_folder(args)
        except Exception as e:
            print("** An exception occurred during training. The exception was:")
            print(str(e))
            print("** DETAILS:")
            print(traceback.format_exc())


def main():
    args = parse_args()
    global_set_root(args.root_dir or Globals.root)
    try:
        if args.front_end:
            do_front_end(args)
        else:
            do_textual_inversion_training(**vars(args))
    except AssertionError as e:
        print(str(e))
        sys.exit(-1)
    except KeyboardInterrupt:
        pass
    except (widget.NotEnoughSpaceForWidget, Exception) as e:
        if str(e).startswith("Height of 1 allocated"):
            print(
                "** You need to have at least one diffusers model defined in models.yaml in order to train"
            )
        elif str(e).startswith("addwstr"):
            print(
                "** Not enough window space for the interface. Please make your window larger and try again."
            )
        else:
            print(f"** An error has occurred: {str(e)}")
        sys.exit(-1)


if __name__ == "__main__":
    main()
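As the training package's __init__.py above shows, this script's main() is
re-exported as invokeai.frontend.training.invokeai_textual_inversion, so a
programmatic launch is simply (a sketch; arguments are still read from the
command line via parse_args()):

    from invokeai.frontend.training import invokeai_textual_inversion

    invokeai_textual_inversion()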
3
invokeai/frontend/web/__init__.py
Normal file
@ -0,0 +1,3 @@
'''
Initialization file for invokeai.frontend.web
'''
Six binary image assets moved unchanged (sizes 116 KiB, 43 KiB, 116 KiB, 336 KiB, 43 KiB, 3.7 KiB).