mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
all vestiges of ldm.invoke removed
This commit is contained in:
@ -8,7 +8,6 @@ from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import click
|
||||
|
||||
from compel import PromptParser
|
||||
|
||||
if sys.platform == "darwin":
|
||||
@ -18,22 +17,23 @@ import pyparsing # type: ignore
|
||||
|
||||
import invokeai.version
|
||||
|
||||
from ...backend import Generate
|
||||
from ...backend.args import (Args,
|
||||
dream_cmd_from_png,
|
||||
metadata_dumps,
|
||||
metadata_from_png)
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ...backend.image_util import make_grid, PngWriter, retrieve_metadata, write_metadata
|
||||
from ...backend import ModelManager
|
||||
from ...backend import Generate, ModelManager
|
||||
from ...backend.args import Args, dream_cmd_from_png, metadata_dumps, metadata_from_png
|
||||
from ...backend.globals import Globals
|
||||
from ...backend.util import write_log
|
||||
from ...backend.image_util import (
|
||||
PngWriter,
|
||||
make_grid,
|
||||
retrieve_metadata,
|
||||
write_metadata,
|
||||
)
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ...backend.util import url_attachment_name, write_log
|
||||
from .readline import Completer, get_completer
|
||||
from ...backend.util import url_attachment_name
|
||||
|
||||
# global used in multiple functions (fix)
|
||||
infile = None
|
||||
|
||||
|
||||
def main():
|
||||
"""Initialize command-line parsers and the diffusion model"""
|
||||
global infile
|
||||
@ -494,7 +494,7 @@ def main_loop(gen, opt):
|
||||
def do_command(command: str, gen, opt: Args, completer) -> tuple:
|
||||
global infile
|
||||
operation = "generate" # default operation, alternative is 'postprocess'
|
||||
command = command.replace('\\','/') # windows
|
||||
command = command.replace("\\", "/") # windows
|
||||
|
||||
if command.startswith(
|
||||
"!dream"
|
||||
@ -537,10 +537,10 @@ def do_command(command: str, gen, opt: Args, completer) -> tuple:
|
||||
import_model(path[1], gen, opt, completer)
|
||||
completer.add_history(command)
|
||||
except KeyboardInterrupt:
|
||||
print('\n')
|
||||
print("\n")
|
||||
operation = None
|
||||
|
||||
elif command.startswith(("!convert","!optimize")):
|
||||
elif command.startswith(("!convert", "!optimize")):
|
||||
path = shlex.split(command)
|
||||
if len(path) < 2:
|
||||
print("** please provide the path to a .ckpt or .safetensors model")
|
||||
@ -549,9 +549,9 @@ def do_command(command: str, gen, opt: Args, completer) -> tuple:
|
||||
convert_model(path[1], gen, opt, completer)
|
||||
completer.add_history(command)
|
||||
except KeyboardInterrupt:
|
||||
print('\n')
|
||||
print("\n")
|
||||
operation = None
|
||||
|
||||
|
||||
elif command.startswith("!edit"):
|
||||
path = shlex.split(command)
|
||||
if len(path) < 2:
|
||||
@ -639,12 +639,12 @@ def import_model(model_path: str, gen, opt, completer, convert=False):
|
||||
):
|
||||
pass
|
||||
else:
|
||||
if model_path.startswith(('http:','https:')):
|
||||
if model_path.startswith(("http:", "https:")):
|
||||
try:
|
||||
default_name = url_attachment_name(model_path)
|
||||
default_name = Path(default_name).stem
|
||||
except Exception as e:
|
||||
print(f'** URL: {str(e)}')
|
||||
print(f"** URL: {str(e)}")
|
||||
model_name, model_desc = _get_model_name_and_desc(
|
||||
gen.model_manager,
|
||||
completer,
|
||||
@ -672,6 +672,7 @@ def import_model(model_path: str, gen, opt, completer, convert=False):
|
||||
completer.update_models(gen.model_manager.list_models())
|
||||
print(f">> {imported_name} successfully installed")
|
||||
|
||||
|
||||
def _verify_load(model_name: str, gen) -> bool:
|
||||
print(">> Verifying that new model loads...")
|
||||
current_model = gen.model_name
|
||||
@ -704,6 +705,7 @@ def _get_model_name_and_desc(
|
||||
)
|
||||
return model_name, model_description
|
||||
|
||||
|
||||
def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
|
||||
model_name_or_path = model_name_or_path.replace("\\", "/") # windows
|
||||
manager = gen.model_manager
|
||||
@ -722,7 +724,9 @@ def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
|
||||
else:
|
||||
print(f"** {model_name_or_path} is not a legacy .ckpt weights file")
|
||||
return
|
||||
if vae_repo := ldm.invoke.model_manager.VAE_TO_REPO_ID.get(Path(vae).stem):
|
||||
if vae_repo := invokeai.backend.model_management.model_manager.VAE_TO_REPO_ID.get(
|
||||
Path(vae).stem
|
||||
):
|
||||
vae_repo = dict(repo_id=vae_repo)
|
||||
else:
|
||||
vae_repo = None
|
||||
@ -742,7 +746,7 @@ def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
|
||||
except KeyboardInterrupt:
|
||||
return
|
||||
|
||||
manager.commit(opt.conf)
|
||||
manager.commit(opt.conf)
|
||||
if click.confirm(f"Delete the original .ckpt file at {ckpt_path}?", default=False):
|
||||
ckpt_path.unlink(missing_ok=True)
|
||||
print(f"{ckpt_path} deleted")
|
||||
@ -1106,7 +1110,7 @@ def make_step_callback(gen, opt, prefix):
|
||||
if step % opt.save_intermediates == 0 or step == opt.steps - 1:
|
||||
filename = os.path.join(destination, f"{step:04}.png")
|
||||
image = gen.sample_to_lowres_estimated_image(latents)
|
||||
image = image.resize((image.size[0]*8,image.size[1]*8))
|
||||
image = image.resize((image.size[0] * 8, image.size[1] * 8))
|
||||
image.save(filename, "PNG")
|
||||
|
||||
return callback
|
||||
@ -1190,8 +1194,8 @@ def report_model_error(opt: Namespace, e: Exception):
|
||||
)
|
||||
else:
|
||||
if not click.confirm(
|
||||
'Do you want to run invokeai-configure script to select and/or reinstall models?',
|
||||
default=False
|
||||
"Do you want to run invokeai-configure script to select and/or reinstall models?",
|
||||
default=False,
|
||||
):
|
||||
return
|
||||
|
||||
@ -1209,9 +1213,9 @@ def report_model_error(opt: Namespace, e: Exception):
|
||||
for arg in yes_to_all.split():
|
||||
sys.argv.append(arg)
|
||||
|
||||
from ldm.invoke.config import invokeai_configure
|
||||
from ..install import invokeai_configure
|
||||
|
||||
invokeai_configure.main()
|
||||
invokeai_configure()
|
||||
print("** InvokeAI will now restart")
|
||||
sys.argv = previous_args
|
||||
main() # would rather do a os.exec(), but doesn't exist?
|
||||
@ -1232,6 +1236,6 @@ def check_internet() -> bool:
|
||||
except:
|
||||
return False
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
'''
|
||||
"""
|
||||
Initialization file for invokeai.frontend.CLI
|
||||
'''
|
||||
"""
|
||||
from .CLI import main as invokeai_command_line_interface
|
||||
|
@ -4,13 +4,14 @@ You may import the global singleton `completer` to get access to the
|
||||
completer object itself. This is useful when you want to autocomplete
|
||||
seeds:
|
||||
|
||||
from ldm.invoke.readline import completer
|
||||
from invokeai.frontend.CLI.readline import completer
|
||||
completer.add_seed(18247566)
|
||||
completer.add_seed(9281839)
|
||||
"""
|
||||
import atexit
|
||||
import os
|
||||
import re
|
||||
import atexit
|
||||
|
||||
from ...backend.args import Args
|
||||
from ...backend.globals import Globals
|
||||
from ...backend.stable_diffusion import HuggingFaceConceptsLibrary
|
||||
@ -18,92 +19,128 @@ from ...backend.stable_diffusion import HuggingFaceConceptsLibrary
|
||||
# ---------------readline utilities---------------------
|
||||
try:
|
||||
import readline
|
||||
|
||||
readline_available = True
|
||||
except (ImportError,ModuleNotFoundError) as e:
|
||||
print(f'** An error occurred when loading the readline module: {str(e)}')
|
||||
except (ImportError, ModuleNotFoundError) as e:
|
||||
print(f"** An error occurred when loading the readline module: {str(e)}")
|
||||
readline_available = False
|
||||
|
||||
IMG_EXTENSIONS = ('.png','.jpg','.jpeg','.PNG','.JPG','.JPEG','.gif','.GIF')
|
||||
WEIGHT_EXTENSIONS = ('.ckpt','.vae','.safetensors')
|
||||
TEXT_EXTENSIONS = ('.txt','.TXT')
|
||||
CONFIG_EXTENSIONS = ('.yaml','.yml')
|
||||
IMG_EXTENSIONS = (".png", ".jpg", ".jpeg", ".PNG", ".JPG", ".JPEG", ".gif", ".GIF")
|
||||
WEIGHT_EXTENSIONS = (".ckpt", ".vae", ".safetensors")
|
||||
TEXT_EXTENSIONS = (".txt", ".TXT")
|
||||
CONFIG_EXTENSIONS = (".yaml", ".yml")
|
||||
COMMANDS = (
|
||||
'--steps','-s',
|
||||
'--seed','-S',
|
||||
'--iterations','-n',
|
||||
'--width','-W','--height','-H',
|
||||
'--cfg_scale','-C',
|
||||
'--threshold',
|
||||
'--perlin',
|
||||
'--grid','-g',
|
||||
'--individual','-i',
|
||||
'--save_intermediates',
|
||||
'--init_img','-I',
|
||||
'--init_mask','-M',
|
||||
'--init_color',
|
||||
'--strength','-f',
|
||||
'--variants','-v',
|
||||
'--outdir','-o',
|
||||
'--sampler','-A','-m',
|
||||
'--embedding_path',
|
||||
'--device',
|
||||
'--grid','-g',
|
||||
'--facetool','-ft',
|
||||
'--facetool_strength','-G',
|
||||
'--codeformer_fidelity','-cf',
|
||||
'--upscale','-U',
|
||||
'-save_orig','--save_original',
|
||||
'--log_tokenization','-t',
|
||||
'--hires_fix',
|
||||
'--inpaint_replace','-r',
|
||||
'--png_compression','-z',
|
||||
'--text_mask','-tm',
|
||||
'--h_symmetry_time_pct',
|
||||
'--v_symmetry_time_pct',
|
||||
'!fix','!fetch','!replay','!history','!search','!clear',
|
||||
'!models','!switch','!import_model','!optimize_model','!convert_model','!edit_model','!del_model',
|
||||
'!mask','!triggers',
|
||||
)
|
||||
MODEL_COMMANDS = (
|
||||
'!switch',
|
||||
'!edit_model',
|
||||
'!del_model',
|
||||
)
|
||||
CKPT_MODEL_COMMANDS = (
|
||||
'!optimize_model',
|
||||
"--steps",
|
||||
"-s",
|
||||
"--seed",
|
||||
"-S",
|
||||
"--iterations",
|
||||
"-n",
|
||||
"--width",
|
||||
"-W",
|
||||
"--height",
|
||||
"-H",
|
||||
"--cfg_scale",
|
||||
"-C",
|
||||
"--threshold",
|
||||
"--perlin",
|
||||
"--grid",
|
||||
"-g",
|
||||
"--individual",
|
||||
"-i",
|
||||
"--save_intermediates",
|
||||
"--init_img",
|
||||
"-I",
|
||||
"--init_mask",
|
||||
"-M",
|
||||
"--init_color",
|
||||
"--strength",
|
||||
"-f",
|
||||
"--variants",
|
||||
"-v",
|
||||
"--outdir",
|
||||
"-o",
|
||||
"--sampler",
|
||||
"-A",
|
||||
"-m",
|
||||
"--embedding_path",
|
||||
"--device",
|
||||
"--grid",
|
||||
"-g",
|
||||
"--facetool",
|
||||
"-ft",
|
||||
"--facetool_strength",
|
||||
"-G",
|
||||
"--codeformer_fidelity",
|
||||
"-cf",
|
||||
"--upscale",
|
||||
"-U",
|
||||
"-save_orig",
|
||||
"--save_original",
|
||||
"--log_tokenization",
|
||||
"-t",
|
||||
"--hires_fix",
|
||||
"--inpaint_replace",
|
||||
"-r",
|
||||
"--png_compression",
|
||||
"-z",
|
||||
"--text_mask",
|
||||
"-tm",
|
||||
"--h_symmetry_time_pct",
|
||||
"--v_symmetry_time_pct",
|
||||
"!fix",
|
||||
"!fetch",
|
||||
"!replay",
|
||||
"!history",
|
||||
"!search",
|
||||
"!clear",
|
||||
"!models",
|
||||
"!switch",
|
||||
"!import_model",
|
||||
"!optimize_model",
|
||||
"!convert_model",
|
||||
"!edit_model",
|
||||
"!del_model",
|
||||
"!mask",
|
||||
"!triggers",
|
||||
)
|
||||
MODEL_COMMANDS = (
|
||||
"!switch",
|
||||
"!edit_model",
|
||||
"!del_model",
|
||||
)
|
||||
CKPT_MODEL_COMMANDS = ("!optimize_model",)
|
||||
WEIGHT_COMMANDS = (
|
||||
'!import_model',
|
||||
'!convert_model',
|
||||
)
|
||||
IMG_PATH_COMMANDS = (
|
||||
'--outdir[=\s]',
|
||||
)
|
||||
TEXT_PATH_COMMANDS=(
|
||||
'!replay',
|
||||
)
|
||||
IMG_FILE_COMMANDS=(
|
||||
'!fix',
|
||||
'!fetch',
|
||||
'!mask',
|
||||
'--init_img[=\s]','-I',
|
||||
'--init_mask[=\s]','-M',
|
||||
'--init_color[=\s]',
|
||||
'--embedding_path[=\s]',
|
||||
)
|
||||
"!import_model",
|
||||
"!convert_model",
|
||||
)
|
||||
IMG_PATH_COMMANDS = ("--outdir[=\s]",)
|
||||
TEXT_PATH_COMMANDS = ("!replay",)
|
||||
IMG_FILE_COMMANDS = (
|
||||
"!fix",
|
||||
"!fetch",
|
||||
"!mask",
|
||||
"--init_img[=\s]",
|
||||
"-I",
|
||||
"--init_mask[=\s]",
|
||||
"-M",
|
||||
"--init_color[=\s]",
|
||||
"--embedding_path[=\s]",
|
||||
)
|
||||
|
||||
path_regexp = "(" + "|".join(IMG_PATH_COMMANDS + IMG_FILE_COMMANDS) + ")\s*\S*$"
|
||||
weight_regexp = "(" + "|".join(WEIGHT_COMMANDS) + ")\s*\S*$"
|
||||
text_regexp = "(" + "|".join(TEXT_PATH_COMMANDS) + ")\s*\S*$"
|
||||
|
||||
path_regexp = '(' + '|'.join(IMG_PATH_COMMANDS+IMG_FILE_COMMANDS) + ')\s*\S*$'
|
||||
weight_regexp = '(' + '|'.join(WEIGHT_COMMANDS) + ')\s*\S*$'
|
||||
text_regexp = '(' + '|'.join(TEXT_PATH_COMMANDS) + ')\s*\S*$'
|
||||
|
||||
class Completer(object):
|
||||
def __init__(self, options, models={}):
|
||||
self.options = sorted(options)
|
||||
self.models = models
|
||||
self.seeds = set()
|
||||
self.matches = list()
|
||||
self.options = sorted(options)
|
||||
self.models = models
|
||||
self.seeds = set()
|
||||
self.matches = list()
|
||||
self.default_dir = None
|
||||
self.linebuffer = None
|
||||
self.linebuffer = None
|
||||
self.auto_history_active = True
|
||||
self.extensions = None
|
||||
self.concepts = None
|
||||
@ -111,40 +148,41 @@ class Completer(object):
|
||||
return
|
||||
|
||||
def complete(self, text, state):
|
||||
'''
|
||||
"""
|
||||
Completes invoke command line.
|
||||
BUG: it doesn't correctly complete files that have spaces in the name.
|
||||
'''
|
||||
"""
|
||||
buffer = readline.get_line_buffer()
|
||||
|
||||
if state == 0:
|
||||
|
||||
# extensions defined, so go directly into path completion mode
|
||||
if self.extensions is not None:
|
||||
self.matches = self._path_completions(text, state, self.extensions)
|
||||
|
||||
# looking for an image file
|
||||
elif re.search(path_regexp,buffer):
|
||||
do_shortcut = re.search('^'+'|'.join(IMG_FILE_COMMANDS),buffer)
|
||||
self.matches = self._path_completions(text, state, IMG_EXTENSIONS,shortcut_ok=do_shortcut)
|
||||
elif re.search(path_regexp, buffer):
|
||||
do_shortcut = re.search("^" + "|".join(IMG_FILE_COMMANDS), buffer)
|
||||
self.matches = self._path_completions(
|
||||
text, state, IMG_EXTENSIONS, shortcut_ok=do_shortcut
|
||||
)
|
||||
|
||||
# looking for a seed
|
||||
elif re.search('(-S\s*|--seed[=\s])\d*$',buffer):
|
||||
self.matches= self._seed_completions(text,state)
|
||||
elif re.search("(-S\s*|--seed[=\s])\d*$", buffer):
|
||||
self.matches = self._seed_completions(text, state)
|
||||
|
||||
# looking for an embedding concept
|
||||
elif re.search('<[\w-]*$',buffer):
|
||||
self.matches= self._concept_completions(text,state)
|
||||
elif re.search("<[\w-]*$", buffer):
|
||||
self.matches = self._concept_completions(text, state)
|
||||
|
||||
# looking for a model
|
||||
elif re.match('^'+'|'.join(MODEL_COMMANDS),buffer):
|
||||
self.matches= self._model_completions(text, state)
|
||||
elif re.match("^" + "|".join(MODEL_COMMANDS), buffer):
|
||||
self.matches = self._model_completions(text, state)
|
||||
|
||||
# looking for a ckpt model
|
||||
elif re.match('^'+'|'.join(CKPT_MODEL_COMMANDS),buffer):
|
||||
self.matches= self._model_completions(text, state, ckpt_only=True)
|
||||
elif re.match("^" + "|".join(CKPT_MODEL_COMMANDS), buffer):
|
||||
self.matches = self._model_completions(text, state, ckpt_only=True)
|
||||
|
||||
elif re.search(weight_regexp,buffer):
|
||||
elif re.search(weight_regexp, buffer):
|
||||
self.matches = self._path_completions(
|
||||
text,
|
||||
state,
|
||||
@ -152,14 +190,12 @@ class Completer(object):
|
||||
default_dir=Globals.root,
|
||||
)
|
||||
|
||||
elif re.search(text_regexp,buffer):
|
||||
elif re.search(text_regexp, buffer):
|
||||
self.matches = self._path_completions(text, state, TEXT_EXTENSIONS)
|
||||
|
||||
# This is the first time for this text, so build a match list.
|
||||
elif text:
|
||||
self.matches = [
|
||||
s for s in self.options if s and s.startswith(text)
|
||||
]
|
||||
self.matches = [s for s in self.options if s and s.startswith(text)]
|
||||
else:
|
||||
self.matches = self.options[:]
|
||||
|
||||
@ -171,50 +207,50 @@ class Completer(object):
|
||||
response = None
|
||||
return response
|
||||
|
||||
def complete_extensions(self, extensions:list):
|
||||
'''
|
||||
def complete_extensions(self, extensions: list):
|
||||
"""
|
||||
If called with a list of extensions, will force completer
|
||||
to do file path completions.
|
||||
'''
|
||||
self.extensions=extensions
|
||||
"""
|
||||
self.extensions = extensions
|
||||
|
||||
def add_history(self,line):
|
||||
'''
|
||||
def add_history(self, line):
|
||||
"""
|
||||
Pass thru to readline
|
||||
'''
|
||||
"""
|
||||
if not self.auto_history_active:
|
||||
readline.add_history(line)
|
||||
|
||||
def clear_history(self):
|
||||
'''
|
||||
"""
|
||||
Pass clear_history() thru to readline
|
||||
'''
|
||||
"""
|
||||
readline.clear_history()
|
||||
|
||||
def search_history(self,match:str):
|
||||
'''
|
||||
def search_history(self, match: str):
|
||||
"""
|
||||
Like show_history() but only shows items that
|
||||
contain the match string.
|
||||
'''
|
||||
"""
|
||||
self.show_history(match)
|
||||
|
||||
def remove_history_item(self,pos):
|
||||
def remove_history_item(self, pos):
|
||||
readline.remove_history_item(pos)
|
||||
|
||||
def add_seed(self, seed):
|
||||
'''
|
||||
"""
|
||||
Add a seed to the autocomplete list for display when -S is autocompleted.
|
||||
'''
|
||||
"""
|
||||
if seed is not None:
|
||||
self.seeds.add(str(seed))
|
||||
|
||||
def set_default_dir(self, path):
|
||||
self.default_dir=path
|
||||
self.default_dir = path
|
||||
|
||||
def set_options(self,options):
|
||||
def set_options(self, options):
|
||||
self.options = options
|
||||
|
||||
def get_line(self,index):
|
||||
def get_line(self, index):
|
||||
try:
|
||||
line = self.get_history_item(index)
|
||||
except IndexError:
|
||||
@ -224,57 +260,58 @@ class Completer(object):
|
||||
def get_current_history_length(self):
|
||||
return readline.get_current_history_length()
|
||||
|
||||
def get_history_item(self,index):
|
||||
def get_history_item(self, index):
|
||||
return readline.get_history_item(index)
|
||||
|
||||
def show_history(self,match=None):
|
||||
'''
|
||||
def show_history(self, match=None):
|
||||
"""
|
||||
Print the session history using the pydoc pager
|
||||
'''
|
||||
"""
|
||||
import pydoc
|
||||
|
||||
lines = list()
|
||||
h_len = self.get_current_history_length()
|
||||
if h_len < 1:
|
||||
print('<empty history>')
|
||||
print("<empty history>")
|
||||
return
|
||||
|
||||
for i in range(0,h_len):
|
||||
line = self.get_history_item(i+1)
|
||||
for i in range(0, h_len):
|
||||
line = self.get_history_item(i + 1)
|
||||
if match and match not in line:
|
||||
continue
|
||||
lines.append(f'[{i+1}] {line}')
|
||||
pydoc.pager('\n'.join(lines))
|
||||
lines.append(f"[{i+1}] {line}")
|
||||
pydoc.pager("\n".join(lines))
|
||||
|
||||
def set_line(self,line)->None:
|
||||
'''
|
||||
def set_line(self, line) -> None:
|
||||
"""
|
||||
Set the default string displayed in the next line of input.
|
||||
'''
|
||||
"""
|
||||
self.linebuffer = line
|
||||
readline.redisplay()
|
||||
|
||||
def update_models(self,models:dict)->None:
|
||||
'''
|
||||
def update_models(self, models: dict) -> None:
|
||||
"""
|
||||
update our list of models
|
||||
'''
|
||||
"""
|
||||
self.models = models
|
||||
|
||||
def _seed_completions(self, text, state):
|
||||
m = re.search('(-S\s?|--seed[=\s]?)(\d*)',text)
|
||||
m = re.search("(-S\s?|--seed[=\s]?)(\d*)", text)
|
||||
if m:
|
||||
switch = m.groups()[0]
|
||||
switch = m.groups()[0]
|
||||
partial = m.groups()[1]
|
||||
else:
|
||||
switch = ''
|
||||
switch = ""
|
||||
partial = text
|
||||
|
||||
matches = list()
|
||||
for s in self.seeds:
|
||||
if s.startswith(partial):
|
||||
matches.append(switch+s)
|
||||
matches.append(switch + s)
|
||||
matches.sort()
|
||||
return matches
|
||||
|
||||
def add_embedding_terms(self, terms:list[str]):
|
||||
def add_embedding_terms(self, terms: list[str]):
|
||||
self.embedding_terms = set(terms)
|
||||
if self.concepts:
|
||||
self.embedding_terms.update(set(self.concepts.list_concepts()))
|
||||
@ -294,27 +331,27 @@ class Completer(object):
|
||||
matches = list()
|
||||
for concept in self.embedding_terms:
|
||||
if concept.startswith(partial):
|
||||
matches.append(f'<{concept}>')
|
||||
matches.append(f"<{concept}>")
|
||||
matches.sort()
|
||||
return matches
|
||||
|
||||
def _model_completions(self, text, state, ckpt_only=False):
|
||||
m = re.search('(!switch\s+)(\w*)',text)
|
||||
m = re.search("(!switch\s+)(\w*)", text)
|
||||
if m:
|
||||
switch = m.groups()[0]
|
||||
switch = m.groups()[0]
|
||||
partial = m.groups()[1]
|
||||
else:
|
||||
switch = ''
|
||||
switch = ""
|
||||
partial = text
|
||||
matches = list()
|
||||
for s in self.models:
|
||||
format = self.models[s]['format']
|
||||
if format == 'vae':
|
||||
format = self.models[s]["format"]
|
||||
if format == "vae":
|
||||
continue
|
||||
if ckpt_only and format != 'ckpt':
|
||||
if ckpt_only and format != "ckpt":
|
||||
continue
|
||||
if s.startswith(partial):
|
||||
matches.append(switch+s)
|
||||
matches.append(switch + s)
|
||||
matches.sort()
|
||||
return matches
|
||||
|
||||
@ -324,14 +361,16 @@ class Completer(object):
|
||||
readline.redisplay()
|
||||
self.linebuffer = None
|
||||
|
||||
def _path_completions(self, text, state, extensions, shortcut_ok=True, default_dir:str=''):
|
||||
def _path_completions(
|
||||
self, text, state, extensions, shortcut_ok=True, default_dir: str = ""
|
||||
):
|
||||
# separate the switch from the partial path
|
||||
match = re.search('^(-\w|--\w+=?)(.*)',text)
|
||||
match = re.search("^(-\w|--\w+=?)(.*)", text)
|
||||
if match is None:
|
||||
switch = None
|
||||
partial_path = text
|
||||
else:
|
||||
switch,partial_path = match.groups()
|
||||
switch, partial_path = match.groups()
|
||||
|
||||
partial_path = partial_path.lstrip()
|
||||
|
||||
@ -340,18 +379,18 @@ class Completer(object):
|
||||
|
||||
if os.path.isdir(path):
|
||||
dir = path
|
||||
elif os.path.dirname(path) != '':
|
||||
elif os.path.dirname(path) != "":
|
||||
dir = os.path.dirname(path)
|
||||
else:
|
||||
dir = default_dir if os.path.exists(default_dir) else ''
|
||||
path= os.path.join(dir,path)
|
||||
dir = default_dir if os.path.exists(default_dir) else ""
|
||||
path = os.path.join(dir, path)
|
||||
|
||||
dir_list = os.listdir(dir or '.')
|
||||
if shortcut_ok and os.path.exists(self.default_dir) and dir=='':
|
||||
dir_list = os.listdir(dir or ".")
|
||||
if shortcut_ok and os.path.exists(self.default_dir) and dir == "":
|
||||
dir_list += os.listdir(self.default_dir)
|
||||
|
||||
for node in dir_list:
|
||||
if node.startswith('.') and len(node) > 1:
|
||||
if node.startswith(".") and len(node) > 1:
|
||||
continue
|
||||
full_path = os.path.join(dir, node)
|
||||
|
||||
@ -362,25 +401,26 @@ class Completer(object):
|
||||
continue
|
||||
|
||||
if switch is None:
|
||||
match_path = os.path.join(dir,node)
|
||||
matches.append(match_path+'/' if os.path.isdir(full_path) else match_path)
|
||||
match_path = os.path.join(dir, node)
|
||||
matches.append(
|
||||
match_path + "/" if os.path.isdir(full_path) else match_path
|
||||
)
|
||||
elif os.path.isdir(full_path):
|
||||
matches.append(
|
||||
switch+os.path.join(os.path.dirname(full_path), node) + '/'
|
||||
switch + os.path.join(os.path.dirname(full_path), node) + "/"
|
||||
)
|
||||
elif node.endswith(extensions):
|
||||
matches.append(
|
||||
switch+os.path.join(os.path.dirname(full_path), node)
|
||||
)
|
||||
matches.append(switch + os.path.join(os.path.dirname(full_path), node))
|
||||
|
||||
return matches
|
||||
|
||||
|
||||
class DummyCompleter(Completer):
|
||||
def __init__(self,options):
|
||||
def __init__(self, options):
|
||||
super().__init__(options)
|
||||
self.history = list()
|
||||
|
||||
def add_history(self,line):
|
||||
def add_history(self, line):
|
||||
self.history.append(line)
|
||||
|
||||
def clear_history(self):
|
||||
@ -389,37 +429,37 @@ class DummyCompleter(Completer):
|
||||
def get_current_history_length(self):
|
||||
return len(self.history)
|
||||
|
||||
def get_history_item(self,index):
|
||||
return self.history[index-1]
|
||||
def get_history_item(self, index):
|
||||
return self.history[index - 1]
|
||||
|
||||
def remove_history_item(self,index):
|
||||
return self.history.pop(index-1)
|
||||
def remove_history_item(self, index):
|
||||
return self.history.pop(index - 1)
|
||||
|
||||
def set_line(self,line):
|
||||
print(f'# {line}')
|
||||
def set_line(self, line):
|
||||
print(f"# {line}")
|
||||
|
||||
def generic_completer(commands:list)->Completer:
|
||||
|
||||
def generic_completer(commands: list) -> Completer:
|
||||
if readline_available:
|
||||
completer = Completer(commands,[])
|
||||
completer = Completer(commands, [])
|
||||
readline.set_completer(completer.complete)
|
||||
readline.set_pre_input_hook(completer._pre_input_hook)
|
||||
readline.set_completer_delims(' ')
|
||||
readline.parse_and_bind('tab: complete')
|
||||
readline.parse_and_bind('set print-completions-horizontally off')
|
||||
readline.parse_and_bind('set page-completions on')
|
||||
readline.parse_and_bind('set skip-completed-text on')
|
||||
readline.parse_and_bind('set show-all-if-ambiguous on')
|
||||
readline.set_completer_delims(" ")
|
||||
readline.parse_and_bind("tab: complete")
|
||||
readline.parse_and_bind("set print-completions-horizontally off")
|
||||
readline.parse_and_bind("set page-completions on")
|
||||
readline.parse_and_bind("set skip-completed-text on")
|
||||
readline.parse_and_bind("set show-all-if-ambiguous on")
|
||||
else:
|
||||
completer = DummyCompleter(commands)
|
||||
return completer
|
||||
|
||||
def get_completer(opt:Args, models=[])->Completer:
|
||||
if readline_available:
|
||||
completer = Completer(COMMANDS,models)
|
||||
|
||||
readline.set_completer(
|
||||
completer.complete
|
||||
)
|
||||
def get_completer(opt: Args, models=[]) -> Completer:
|
||||
if readline_available:
|
||||
completer = Completer(COMMANDS, models)
|
||||
|
||||
readline.set_completer(completer.complete)
|
||||
# pyreadline3 does not have a set_auto_history() method
|
||||
try:
|
||||
readline.set_auto_history(False)
|
||||
@ -427,27 +467,29 @@ def get_completer(opt:Args, models=[])->Completer:
|
||||
except:
|
||||
completer.auto_history_active = True
|
||||
readline.set_pre_input_hook(completer._pre_input_hook)
|
||||
readline.set_completer_delims(' ')
|
||||
readline.parse_and_bind('tab: complete')
|
||||
readline.parse_and_bind('set print-completions-horizontally off')
|
||||
readline.parse_and_bind('set page-completions on')
|
||||
readline.parse_and_bind('set skip-completed-text on')
|
||||
readline.parse_and_bind('set show-all-if-ambiguous on')
|
||||
readline.set_completer_delims(" ")
|
||||
readline.parse_and_bind("tab: complete")
|
||||
readline.parse_and_bind("set print-completions-horizontally off")
|
||||
readline.parse_and_bind("set page-completions on")
|
||||
readline.parse_and_bind("set skip-completed-text on")
|
||||
readline.parse_and_bind("set show-all-if-ambiguous on")
|
||||
|
||||
outdir = os.path.expanduser(opt.outdir)
|
||||
if os.path.isabs(outdir):
|
||||
histfile = os.path.join(outdir,'.invoke_history')
|
||||
histfile = os.path.join(outdir, ".invoke_history")
|
||||
else:
|
||||
histfile = os.path.join(Globals.root, outdir, '.invoke_history')
|
||||
histfile = os.path.join(Globals.root, outdir, ".invoke_history")
|
||||
try:
|
||||
readline.read_history_file(histfile)
|
||||
readline.set_history_length(1000)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except OSError: # file likely corrupted
|
||||
newname = f'{histfile}.old'
|
||||
print(f'## Your history file {histfile} couldn\'t be loaded and may be corrupted. Renaming it to {newname}')
|
||||
os.replace(histfile,newname)
|
||||
except OSError: # file likely corrupted
|
||||
newname = f"{histfile}.old"
|
||||
print(
|
||||
f"## Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}"
|
||||
)
|
||||
os.replace(histfile, newname)
|
||||
atexit.register(readline.write_history_file, histfile)
|
||||
|
||||
else:
|
||||
|
@ -1,7 +1,6 @@
|
||||
'''
|
||||
"""
|
||||
Initialization file for invokeai.frontend.config
|
||||
'''
|
||||
from .model_install import main as invokeai_model_install
|
||||
"""
|
||||
from .invokeai_configure import main as invokeai_configure
|
||||
from .invokeai_update import main as invokeai_update
|
||||
|
||||
from .model_install import main as invokeai_model_install
|
@ -1,4 +1,4 @@
|
||||
'''
|
||||
"""
|
||||
Wrapper for invokeai.backend.configure.invokeai_configure
|
||||
'''
|
||||
"""
|
||||
from ...backend.config.invokeai_configure import main
|
@ -1,9 +1,10 @@
|
||||
'''
|
||||
"""
|
||||
Minimalist updater script. Prompts user for the tag or branch to update to and runs
|
||||
pip install <path_to_git_source>.
|
||||
'''
|
||||
"""
|
||||
import os
|
||||
import platform
|
||||
|
||||
import requests
|
||||
from rich import box, print
|
||||
from rich.console import Console, Group, group
|
||||
@ -15,8 +16,8 @@ from rich.text import Text
|
||||
|
||||
from invokeai.version import __version__
|
||||
|
||||
INVOKE_AI_SRC="https://github.com/invoke-ai/InvokeAI/archive"
|
||||
INVOKE_AI_REL="https://api.github.com/repos/invoke-ai/InvokeAI/releases"
|
||||
INVOKE_AI_SRC = "https://github.com/invoke-ai/InvokeAI/archive"
|
||||
INVOKE_AI_REL = "https://api.github.com/repos/invoke-ai/InvokeAI/releases"
|
||||
|
||||
OS = platform.uname().system
|
||||
ARCH = platform.uname().machine
|
||||
@ -27,21 +28,22 @@ if OS == "Windows":
|
||||
else:
|
||||
console = Console(style=Style(color="grey74", bgcolor="grey19"))
|
||||
|
||||
def get_versions()->dict:
|
||||
|
||||
def get_versions() -> dict:
|
||||
return requests.get(url=INVOKE_AI_REL).json()
|
||||
|
||||
|
||||
def welcome(versions: dict):
|
||||
|
||||
@group()
|
||||
def text():
|
||||
yield f'InvokeAI Version: [bold yellow]{__version__}'
|
||||
yield ''
|
||||
yield 'This script will update InvokeAI to the latest release, or to a development version of your choice.'
|
||||
yield ''
|
||||
yield '[bold yellow]Options:'
|
||||
yield f'''[1] Update to the latest official release ([italic]{versions[0]['tag_name']}[/italic])
|
||||
yield f"InvokeAI Version: [bold yellow]{__version__}"
|
||||
yield ""
|
||||
yield "This script will update InvokeAI to the latest release, or to a development version of your choice."
|
||||
yield ""
|
||||
yield "[bold yellow]Options:"
|
||||
yield f"""[1] Update to the latest official release ([italic]{versions[0]['tag_name']}[/italic])
|
||||
[2] Update to the bleeding-edge development version ([italic]main[/italic])
|
||||
[3] Manually enter the tag or branch name you wish to update'''
|
||||
[3] Manually enter the tag or branch name you wish to update"""
|
||||
|
||||
console.rule()
|
||||
print(
|
||||
@ -57,32 +59,33 @@ def welcome(versions: dict):
|
||||
)
|
||||
console.line()
|
||||
|
||||
|
||||
def main():
|
||||
versions = get_versions()
|
||||
welcome(versions)
|
||||
|
||||
tag = None
|
||||
choice = Prompt.ask('Choice:',choices=['1','2','3'],default='1')
|
||||
|
||||
if choice=='1':
|
||||
tag = versions[0]['tag_name']
|
||||
elif choice=='2':
|
||||
tag = 'main'
|
||||
elif choice=='3':
|
||||
tag = Prompt.ask('Enter an InvokeAI tag or branch name')
|
||||
choice = Prompt.ask("Choice:", choices=["1", "2", "3"], default="1")
|
||||
|
||||
print(f':crossed_fingers: Upgrading to [yellow]{tag}[/yellow]')
|
||||
cmd = f'pip install {INVOKE_AI_SRC}/{tag}.zip --use-pep517'
|
||||
print('')
|
||||
print('')
|
||||
if os.system(cmd)==0:
|
||||
print(f':heavy_check_mark: Upgrade successful')
|
||||
if choice == "1":
|
||||
tag = versions[0]["tag_name"]
|
||||
elif choice == "2":
|
||||
tag = "main"
|
||||
elif choice == "3":
|
||||
tag = Prompt.ask("Enter an InvokeAI tag or branch name")
|
||||
|
||||
print(f":crossed_fingers: Upgrading to [yellow]{tag}[/yellow]")
|
||||
cmd = f"pip install {INVOKE_AI_SRC}/{tag}.zip --use-pep517"
|
||||
print("")
|
||||
print("")
|
||||
if os.system(cmd) == 0:
|
||||
print(f":heavy_check_mark: Upgrade successful")
|
||||
else:
|
||||
print(f':exclamation: [bold red]Upgrade failed[/red bold]')
|
||||
|
||||
print(f":exclamation: [bold red]Upgrade failed[/red bold]")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
main()
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
@ -14,34 +14,42 @@ import os
|
||||
import sys
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
from shutil import get_terminal_size
|
||||
from typing import List
|
||||
|
||||
import npyscreen
|
||||
import torch
|
||||
from npyscreen import widget
|
||||
from omegaconf import OmegaConf
|
||||
from shutil import get_terminal_size
|
||||
|
||||
from ...backend.util import choose_precision, choose_torch_device
|
||||
from invokeai.backend.globals import Globals, global_config_dir
|
||||
from ...backend.config.model_install_backend import (Dataset_path, default_config_file,
|
||||
default_dataset, get_root,
|
||||
install_requested_models,
|
||||
recommended_datasets,
|
||||
)
|
||||
from .widgets import (MultiSelectColumns, TextBox,
|
||||
OffsetButtonPress, CenteredTitleText,
|
||||
set_min_terminal_size,
|
||||
)
|
||||
|
||||
from ...backend.config.model_install_backend import (
|
||||
Dataset_path,
|
||||
default_config_file,
|
||||
default_dataset,
|
||||
get_root,
|
||||
install_requested_models,
|
||||
recommended_datasets,
|
||||
)
|
||||
from ...backend.util import choose_precision, choose_torch_device
|
||||
from .widgets import (
|
||||
CenteredTitleText,
|
||||
MultiSelectColumns,
|
||||
OffsetButtonPress,
|
||||
TextBox,
|
||||
set_min_terminal_size,
|
||||
)
|
||||
|
||||
# minimum size for the UI
|
||||
MIN_COLS = 120
|
||||
MIN_LINES = 45
|
||||
|
||||
|
||||
class addModelsForm(npyscreen.FormMultiPage):
|
||||
# for responsive resizing - disabled
|
||||
#FIX_MINIMUM_SIZE_WHEN_CREATED = False
|
||||
|
||||
# FIX_MINIMUM_SIZE_WHEN_CREATED = False
|
||||
|
||||
def __init__(self, parentApp, name, multipage=False, *args, **keywords):
|
||||
self.multipage = multipage
|
||||
self.initial_models = OmegaConf.load(Dataset_path)
|
||||
@ -71,13 +79,13 @@ class addModelsForm(npyscreen.FormMultiPage):
|
||||
npyscreen.FixedText,
|
||||
value="Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields,",
|
||||
editable=False,
|
||||
color='CAUTION',
|
||||
color="CAUTION",
|
||||
)
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value="Use cursor arrows to make a selection, and space to toggle checkboxes.",
|
||||
editable=False,
|
||||
color='CAUTION'
|
||||
color="CAUTION",
|
||||
)
|
||||
self.nextrely += 1
|
||||
if len(self.installed_models) > 0:
|
||||
@ -147,30 +155,26 @@ class addModelsForm(npyscreen.FormMultiPage):
|
||||
)
|
||||
self.add_widget_intelligent(
|
||||
CenteredTitleText,
|
||||
name='== IMPORT LOCAL AND REMOTE MODELS ==',
|
||||
name="== IMPORT LOCAL AND REMOTE MODELS ==",
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
)
|
||||
self.nextrely -= 1
|
||||
|
||||
for line in [
|
||||
"In the box below, enter URLs, file paths, or HuggingFace repository IDs.",
|
||||
"Separate model names by lines or whitespace (Use shift-control-V to paste):",
|
||||
"In the box below, enter URLs, file paths, or HuggingFace repository IDs.",
|
||||
"Separate model names by lines or whitespace (Use shift-control-V to paste):",
|
||||
]:
|
||||
self.add_widget_intelligent(
|
||||
CenteredTitleText,
|
||||
name=line,
|
||||
editable=False,
|
||||
labelColor="CONTROL",
|
||||
relx = 4,
|
||||
relx=4,
|
||||
)
|
||||
self.nextrely -= 1
|
||||
self.import_model_paths = self.add_widget_intelligent(
|
||||
TextBox,
|
||||
max_height=7,
|
||||
scroll_exit=True,
|
||||
editable=True,
|
||||
relx=4
|
||||
TextBox, max_height=7, scroll_exit=True, editable=True, relx=4
|
||||
)
|
||||
self.nextrely += 1
|
||||
self.show_directory_fields = self.add_widget_intelligent(
|
||||
@ -245,7 +249,7 @@ class addModelsForm(npyscreen.FormMultiPage):
|
||||
|
||||
def resize(self):
|
||||
super().resize()
|
||||
if hasattr(self,'models_selected'):
|
||||
if hasattr(self, "models_selected"):
|
||||
self.models_selected.values = self._get_starter_model_labels()
|
||||
|
||||
def _clear_scan_directory(self):
|
||||
@ -325,10 +329,11 @@ class addModelsForm(npyscreen.FormMultiPage):
|
||||
selections = self.parentApp.user_selections
|
||||
|
||||
# starter models to install/remove
|
||||
if hasattr(self,'models_selected'):
|
||||
if hasattr(self, "models_selected"):
|
||||
starter_models = dict(
|
||||
map(
|
||||
lambda x: (self.starter_model_list[x], True), self.models_selected.value
|
||||
lambda x: (self.starter_model_list[x], True),
|
||||
self.models_selected.value,
|
||||
)
|
||||
)
|
||||
else:
|
||||
@ -376,6 +381,7 @@ class AddModelApplication(npyscreen.NPSAppManaged):
|
||||
"MAIN", addModelsForm, name="Install Stable Diffusion Models"
|
||||
)
|
||||
|
||||
|
||||
# --------------------------------------------------------
|
||||
def process_and_execute(opt: Namespace, selections: Namespace):
|
||||
models_to_remove = [
|
||||
@ -477,9 +483,9 @@ def main():
|
||||
print(
|
||||
">> Your InvokeAI root directory is not set up. Calling invokeai-configure."
|
||||
)
|
||||
import ldm.invoke.config.invokeai_configure
|
||||
from invokeai.frontend.install import invokeai_configure
|
||||
|
||||
ldm.invoke.config.invokeai_configure.main()
|
||||
invokeai_configure()
|
||||
sys.exit(0)
|
||||
|
||||
try:
|
||||
@ -499,6 +505,7 @@ def main():
|
||||
"** Insufficient horizontal space for the interface. Please make your window wider and try again."
|
||||
)
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,35 +1,39 @@
|
||||
'''
|
||||
"""
|
||||
Widget class definitions used by model_select.py, merge_diffusers.py and textual_inversion.py
|
||||
'''
|
||||
import math
|
||||
import platform
|
||||
import npyscreen
|
||||
import os
|
||||
import sys
|
||||
"""
|
||||
import curses
|
||||
import math
|
||||
import os
|
||||
import platform
|
||||
import struct
|
||||
|
||||
import sys
|
||||
from shutil import get_terminal_size
|
||||
|
||||
import npyscreen
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def set_terminal_size(columns: int, lines: int):
|
||||
OS = platform.uname().system
|
||||
if OS=="Windows":
|
||||
os.system(f'mode con: cols={columns} lines={lines}')
|
||||
elif OS in ['Darwin', 'Linux']:
|
||||
import termios
|
||||
if OS == "Windows":
|
||||
os.system(f"mode con: cols={columns} lines={lines}")
|
||||
elif OS in ["Darwin", "Linux"]:
|
||||
import fcntl
|
||||
import termios
|
||||
|
||||
winsize = struct.pack("HHHH", lines, columns, 0, 0)
|
||||
fcntl.ioctl(sys.stdout.fileno(), termios.TIOCSWINSZ, winsize)
|
||||
sys.stdout.write("\x1b[8;{rows};{cols}t".format(rows=lines, cols=columns))
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
def set_min_terminal_size(min_cols: int, min_lines: int):
|
||||
# make sure there's enough room for the ui
|
||||
term_cols, term_lines = get_terminal_size()
|
||||
cols = max(term_cols, min_cols)
|
||||
cols = max(term_cols, min_cols)
|
||||
lines = max(term_lines, min_lines)
|
||||
set_terminal_size(cols,lines)
|
||||
set_terminal_size(cols, lines)
|
||||
|
||||
|
||||
class IntSlider(npyscreen.Slider):
|
||||
def translate_value(self):
|
||||
@ -38,18 +42,20 @@ class IntSlider(npyscreen.Slider):
|
||||
stri = stri.rjust(l)
|
||||
return stri
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
class CenteredTitleText(npyscreen.TitleText):
|
||||
def __init__(self,*args,**keywords):
|
||||
super().__init__(*args,**keywords)
|
||||
def __init__(self, *args, **keywords):
|
||||
super().__init__(*args, **keywords)
|
||||
self.resize()
|
||||
|
||||
|
||||
def resize(self):
|
||||
super().resize()
|
||||
maxy, maxx = self.parent.curses_pad.getmaxyx()
|
||||
label = self.name
|
||||
self.relx = (maxx - len(label)) // 2
|
||||
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
class CenteredButtonPress(npyscreen.ButtonPress):
|
||||
def resize(self):
|
||||
@ -57,21 +63,24 @@ class CenteredButtonPress(npyscreen.ButtonPress):
|
||||
maxy, maxx = self.parent.curses_pad.getmaxyx()
|
||||
label = self.name
|
||||
self.relx = (maxx - len(label)) // 2
|
||||
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
class OffsetButtonPress(npyscreen.ButtonPress):
|
||||
def __init__(self, screen, offset=0, *args, **keywords):
|
||||
def __init__(self, screen, offset=0, *args, **keywords):
|
||||
super().__init__(screen, *args, **keywords)
|
||||
self.offset = offset
|
||||
|
||||
|
||||
def resize(self):
|
||||
maxy, maxx = self.parent.curses_pad.getmaxyx()
|
||||
width = len(self.name)
|
||||
self.relx = self.offset + (maxx - width) // 2
|
||||
|
||||
|
||||
class IntTitleSlider(npyscreen.TitleText):
|
||||
_entry_type = IntSlider
|
||||
|
||||
|
||||
class FloatSlider(npyscreen.Slider):
|
||||
# this is supposed to adjust display precision, but doesn't
|
||||
def translate_value(self):
|
||||
@ -80,85 +89,114 @@ class FloatSlider(npyscreen.Slider):
|
||||
stri = stri.rjust(l)
|
||||
return stri
|
||||
|
||||
|
||||
class FloatTitleSlider(npyscreen.TitleText):
|
||||
_entry_type = FloatSlider
|
||||
|
||||
|
||||
class MultiSelectColumns(npyscreen.MultiSelect):
|
||||
def __init__(self, screen, columns: int=1, values: list=[], **keywords):
|
||||
def __init__(self, screen, columns: int = 1, values: list = [], **keywords):
|
||||
self.columns = columns
|
||||
self.value_cnt = len(values)
|
||||
self.rows = math.ceil(self.value_cnt / self.columns)
|
||||
super().__init__(screen,values=values, **keywords)
|
||||
super().__init__(screen, values=values, **keywords)
|
||||
|
||||
def make_contained_widgets(self):
|
||||
self._my_widgets = []
|
||||
column_width = self.width // self.columns
|
||||
for h in range(self.value_cnt):
|
||||
self._my_widgets.append(
|
||||
self._contained_widgets(self.parent,
|
||||
rely=self.rely + (h % self.rows) * self._contained_widget_height,
|
||||
relx=self.relx + (h // self.rows) * column_width,
|
||||
max_width=column_width,
|
||||
max_height=self.__class__._contained_widget_height,
|
||||
)
|
||||
self._contained_widgets(
|
||||
self.parent,
|
||||
rely=self.rely + (h % self.rows) * self._contained_widget_height,
|
||||
relx=self.relx + (h // self.rows) * column_width,
|
||||
max_width=column_width,
|
||||
max_height=self.__class__._contained_widget_height,
|
||||
)
|
||||
)
|
||||
|
||||
def set_up_handlers(self):
|
||||
super().set_up_handlers()
|
||||
self.handlers.update({
|
||||
curses.KEY_UP: self.h_cursor_line_left,
|
||||
curses.KEY_DOWN: self.h_cursor_line_right,
|
||||
}
|
||||
)
|
||||
self.handlers.update(
|
||||
{
|
||||
curses.KEY_UP: self.h_cursor_line_left,
|
||||
curses.KEY_DOWN: self.h_cursor_line_right,
|
||||
}
|
||||
)
|
||||
|
||||
def h_cursor_line_down(self, ch):
|
||||
self.cursor_line += self.rows
|
||||
if self.cursor_line >= len(self.values):
|
||||
if self.scroll_exit:
|
||||
self.cursor_line = len(self.values)-self.rows
|
||||
if self.scroll_exit:
|
||||
self.cursor_line = len(self.values) - self.rows
|
||||
self.h_exit_down(ch)
|
||||
return True
|
||||
else:
|
||||
else:
|
||||
self.cursor_line -= self.rows
|
||||
return True
|
||||
|
||||
def h_cursor_line_up(self, ch):
|
||||
self.cursor_line -= self.rows
|
||||
if self.cursor_line < 0:
|
||||
if self.cursor_line < 0:
|
||||
if self.scroll_exit:
|
||||
self.cursor_line = 0
|
||||
self.h_exit_up(ch)
|
||||
else:
|
||||
else:
|
||||
self.cursor_line = 0
|
||||
|
||||
def h_cursor_line_left(self,ch):
|
||||
def h_cursor_line_left(self, ch):
|
||||
super().h_cursor_line_up(ch)
|
||||
|
||||
def h_cursor_line_right(self,ch):
|
||||
|
||||
def h_cursor_line_right(self, ch):
|
||||
super().h_cursor_line_down(ch)
|
||||
|
||||
|
||||
class TextBox(npyscreen.MultiLineEdit):
|
||||
def update(self, clear=True):
|
||||
if clear: self.clear()
|
||||
if clear:
|
||||
self.clear()
|
||||
|
||||
HEIGHT = self.height
|
||||
WIDTH = self.width
|
||||
WIDTH = self.width
|
||||
# draw box.
|
||||
self.parent.curses_pad.hline(self.rely, self.relx, curses.ACS_HLINE, WIDTH)
|
||||
self.parent.curses_pad.hline(self.rely + HEIGHT, self.relx, curses.ACS_HLINE, WIDTH)
|
||||
self.parent.curses_pad.vline(self.rely, self.relx, curses.ACS_VLINE, self.height)
|
||||
self.parent.curses_pad.vline(self.rely, self.relx+WIDTH, curses.ACS_VLINE, HEIGHT)
|
||||
|
||||
self.parent.curses_pad.hline(
|
||||
self.rely + HEIGHT, self.relx, curses.ACS_HLINE, WIDTH
|
||||
)
|
||||
self.parent.curses_pad.vline(
|
||||
self.rely, self.relx, curses.ACS_VLINE, self.height
|
||||
)
|
||||
self.parent.curses_pad.vline(
|
||||
self.rely, self.relx + WIDTH, curses.ACS_VLINE, HEIGHT
|
||||
)
|
||||
|
||||
# draw corners
|
||||
self.parent.curses_pad.addch(self.rely, self.relx, curses.ACS_ULCORNER, )
|
||||
self.parent.curses_pad.addch(self.rely, self.relx+WIDTH, curses.ACS_URCORNER, )
|
||||
self.parent.curses_pad.addch(self.rely+HEIGHT, self.relx, curses.ACS_LLCORNER, )
|
||||
self.parent.curses_pad.addch(self.rely+HEIGHT, self.relx+WIDTH, curses.ACS_LRCORNER, )
|
||||
|
||||
self.parent.curses_pad.addch(
|
||||
self.rely,
|
||||
self.relx,
|
||||
curses.ACS_ULCORNER,
|
||||
)
|
||||
self.parent.curses_pad.addch(
|
||||
self.rely,
|
||||
self.relx + WIDTH,
|
||||
curses.ACS_URCORNER,
|
||||
)
|
||||
self.parent.curses_pad.addch(
|
||||
self.rely + HEIGHT,
|
||||
self.relx,
|
||||
curses.ACS_LLCORNER,
|
||||
)
|
||||
self.parent.curses_pad.addch(
|
||||
self.rely + HEIGHT,
|
||||
self.relx + WIDTH,
|
||||
curses.ACS_LRCORNER,
|
||||
)
|
||||
|
||||
# fool our superclass into thinking drawing area is smaller - this is really hacky but it seems to work
|
||||
(relx,rely,height,width) = (self.relx, self.rely, self.height, self.width)
|
||||
(relx, rely, height, width) = (self.relx, self.rely, self.height, self.width)
|
||||
self.relx += 1
|
||||
self.rely += 1
|
||||
self.height -= 1
|
||||
self.width -= 1
|
||||
super().update(clear=False)
|
||||
(self.relx,self.rely,self.height,self.width) = (relx, rely, height, width)
|
||||
(self.relx, self.rely, self.height, self.width) = (relx, rely, height, width)
|
@ -1,4 +1,4 @@
|
||||
'''
|
||||
"""
|
||||
Initialization file for invokeai.frontend.merge
|
||||
'''
|
||||
"""
|
||||
from .merge_diffusers import main as invokeai_merge_diffusers
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""
|
||||
ldm.invoke.merge_diffusers exports a single function call merge_diffusion_models()
|
||||
invokeai.frontend.merge exports a single function call merge_diffusion_models()
|
||||
used to merge 2-3 models together and create a new InvokeAI-registered diffusion model.
|
||||
|
||||
Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
|
||||
@ -20,13 +20,19 @@ from diffusers import logging as dlogging
|
||||
from npyscreen import widget
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from ...frontend.config.widgets import FloatTitleSlider
|
||||
from ...backend.globals import (Globals, global_cache_dir, global_config_file,
|
||||
global_models_dir, global_set_root)
|
||||
from ...backend.globals import (
|
||||
Globals,
|
||||
global_cache_dir,
|
||||
global_config_file,
|
||||
global_models_dir,
|
||||
global_set_root,
|
||||
)
|
||||
from ...backend.model_management import ModelManager
|
||||
from ...frontend.install.widgets import FloatTitleSlider
|
||||
|
||||
DEST_MERGED_MODEL_DIR = "merged_models"
|
||||
|
||||
|
||||
def merge_diffusion_models(
|
||||
model_ids_or_paths: List[Union[str, Path]],
|
||||
alpha: float = 0.5,
|
||||
@ -310,8 +316,8 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
self.merged_model_name.value = merged_model_name
|
||||
|
||||
if selected_model3 > 0:
|
||||
self.merge_method.values = ['add_difference ( A+(B-C) )']
|
||||
self.merged_model_name.value += f"+{models[selected_model3 -1]}" # In model3 there is one more element in the list (None). So we have to subtract one.
|
||||
self.merge_method.values = ["add_difference ( A+(B-C) )"]
|
||||
self.merged_model_name.value += f"+{models[selected_model3 -1]}" # In model3 there is one more element in the list (None). So we have to subtract one.
|
||||
else:
|
||||
self.merge_method.values = self.interpolations
|
||||
self.merge_method.value = 0
|
||||
@ -336,9 +342,9 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
]
|
||||
if self.model3.value[0] > 0:
|
||||
models.append(model_names[self.model3.value[0] - 1])
|
||||
interp='add_difference'
|
||||
interp = "add_difference"
|
||||
else:
|
||||
interp=self.interpolations[self.merge_method.value[0]]
|
||||
interp = self.interpolations[self.merge_method.value[0]]
|
||||
|
||||
args = dict(
|
||||
models=models,
|
||||
@ -453,7 +459,9 @@ def main():
|
||||
"** You need to have at least two diffusers models defined in models.yaml in order to merge"
|
||||
)
|
||||
else:
|
||||
print("** Not enough room for the user interface. Try making this window larger.")
|
||||
print(
|
||||
"** Not enough room for the user interface. Try making this window larger."
|
||||
)
|
||||
sys.exit(-1)
|
||||
except Exception:
|
||||
print(">> An error occurred:")
|
||||
|
@ -1,5 +1,4 @@
|
||||
'''
|
||||
"""
|
||||
Initialization file for invokeai.frontend.training
|
||||
'''
|
||||
"""
|
||||
from .textual_inversion import main as invokeai_textual_inversion
|
||||
|
||||
|
@ -21,10 +21,8 @@ from npyscreen import widget
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from invokeai.backend.globals import Globals, global_set_root
|
||||
from ...backend.training import (
|
||||
do_textual_inversion_training,
|
||||
parse_args,
|
||||
)
|
||||
|
||||
from ...backend.training import do_textual_inversion_training, parse_args
|
||||
|
||||
TRAINING_DATA = "text-inversion-training-data"
|
||||
TRAINING_DIR = "text-inversion-output"
|
||||
@ -448,9 +446,9 @@ def main():
|
||||
print(
|
||||
"** You need to have at least one diffusers models defined in models.yaml in order to train"
|
||||
)
|
||||
elif str(e).startswith('addwstr'):
|
||||
elif str(e).startswith("addwstr"):
|
||||
print(
|
||||
'** Not enough window space for the interface. Please make your window larger and try again.'
|
||||
"** Not enough window space for the interface. Please make your window larger and try again."
|
||||
)
|
||||
else:
|
||||
print(f"** An error has occurred: {str(e)}")
|
||||
|
Reference in New Issue
Block a user