mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
model installer frontend done - needs to be hooked to backend
This commit is contained in:
parent
f299f40763
commit
e87a2fe14b
@ -3,79 +3,27 @@
|
||||
# Before running stable-diffusion on an internet-isolated machine,
|
||||
# run this script from one with internet connectivity. The
|
||||
# two machines must share a common .cache directory.
|
||||
#
|
||||
# Coauthor: Kevin Turner http://github.com/keturn
|
||||
#
|
||||
|
||||
import argparse
|
||||
import curses
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
import traceback
|
||||
import warnings
|
||||
from argparse import Namespace
|
||||
from math import ceil
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryFile
|
||||
from typing import List
|
||||
|
||||
import npyscreen
|
||||
import requests
|
||||
from diffusers import AutoencoderKL
|
||||
from huggingface_hub import hf_hub_url
|
||||
import torch
|
||||
from npyscreen import widget
|
||||
from omegaconf import OmegaConf
|
||||
from omegaconf.dictconfig import DictConfig
|
||||
from tqdm import tqdm
|
||||
|
||||
import invokeai.configs as configs
|
||||
from ldm.invoke.devices import choose_precision, choose_torch_device
|
||||
from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
||||
from ldm.invoke.globals import Globals, global_cache_dir, global_config_dir
|
||||
from ldm.invoke.config.widgets import MultiSelectColumns
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
import torch
|
||||
|
||||
# --------------------------globals-----------------------
|
||||
Model_dir = "models"
|
||||
Weights_dir = "ldm/stable-diffusion-v1/"
|
||||
|
||||
# the initial "configs" dir is now bundled in the `invokeai.configs` package
|
||||
Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml"
|
||||
|
||||
Default_config_file = Path(global_config_dir()) / "models.yaml"
|
||||
SD_Configs = Path(global_config_dir()) / "stable-diffusion"
|
||||
|
||||
Datasets = OmegaConf.load(Dataset_path)
|
||||
|
||||
Config_preamble = """# This file describes the alternative machine learning models
|
||||
# available to InvokeAI script.
|
||||
#
|
||||
# To add a new model, follow the examples below. Each
|
||||
# model requires a model config file, a weights file,
|
||||
# and the width and height of the images it
|
||||
# was trained on.
|
||||
"""
|
||||
|
||||
# -------------------------------------
|
||||
def yes_or_no(prompt: str, default_yes=True):
|
||||
default = "y" if default_yes else "n"
|
||||
response = input(f"{prompt} [{default}] ") or default
|
||||
if default_yes:
|
||||
return response[0] not in ("n", "N")
|
||||
else:
|
||||
return response[0] in ("y", "Y")
|
||||
|
||||
# -------------------------------------
|
||||
def get_root(root: str = None) -> str:
|
||||
if root:
|
||||
return root
|
||||
elif os.environ.get("INVOKEAI_ROOT"):
|
||||
return os.environ.get("INVOKEAI_ROOT")
|
||||
else:
|
||||
return Globals.root
|
||||
|
||||
from ..devices import choose_precision, choose_torch_device
|
||||
from ..globals import Globals
|
||||
from .widgets import MultiSelectColumns, TextBox
|
||||
from .model_install_util import (Dataset_path, Default_config_file,
|
||||
default_dataset, download_weight_datasets,
|
||||
update_config_file, get_root
|
||||
)
|
||||
|
||||
class addModelsForm(npyscreen.FormMultiPageAction):
|
||||
def __init__(self, parentApp, name):
|
||||
@ -98,38 +46,53 @@ class addModelsForm(npyscreen.FormMultiPageAction):
|
||||
for x in self.starter_model_list
|
||||
if self.initial_models[x].get("recommended", False)
|
||||
]
|
||||
previously_installed_models = sorted(
|
||||
self.installed_models = sorted(
|
||||
[
|
||||
x for x in list(self.initial_models.keys()) if x in self.existing_models
|
||||
]
|
||||
)
|
||||
|
||||
if len(previously_installed_models) > 0:
|
||||
title = self.add_widget_intelligent(
|
||||
npyscreen.TitleText,
|
||||
name="Currently installed starter models. Uncheck to delete:",
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value='Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields,',
|
||||
editable=False,
|
||||
)
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value='cursor arrows to make a selection, and space to toggle checkboxes.',
|
||||
editable=False,
|
||||
)
|
||||
|
||||
if len(self.installed_models) > 0:
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.TitleFixedText,
|
||||
name="== INSTALLED STARTER MODELS ==",
|
||||
value="Currently installed starter models. Uncheck to delete:",
|
||||
begin_entry_at=2,
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
)
|
||||
self.nextrely -= 1
|
||||
columns = 3
|
||||
columns = self._get_columns()
|
||||
self.previously_installed_models = self.add_widget_intelligent(
|
||||
MultiSelectColumns,
|
||||
columns=columns,
|
||||
values=previously_installed_models,
|
||||
value=[x for x in range(0,len(previously_installed_models))],
|
||||
max_height=len(previously_installed_models)+1 // columns,
|
||||
values=self.installed_models,
|
||||
value=[x for x in range(0,len(self.installed_models))],
|
||||
max_height=2+len(self.installed_models) // columns,
|
||||
relx = 4,
|
||||
slow_scroll=True,
|
||||
scroll_exit = True,
|
||||
)
|
||||
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.TitleText,
|
||||
name="Select from a starter set of Stable Diffusion models from HuggingFace:",
|
||||
npyscreen.TitleFixedText,
|
||||
name="== UNINSTALLED STARTER MODELS ==",
|
||||
value="Select from a starter set of Stable Diffusion models from HuggingFace:",
|
||||
begin_entry_at=2,
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
)
|
||||
self.nextrely -= 2
|
||||
self.add_widget_intelligent(npyscreen.FixedText, value="", editable=False),
|
||||
self.nextrely -= 1
|
||||
self.models_selected = self.add_widget_intelligent(
|
||||
npyscreen.MultiSelect,
|
||||
name="Install Starter Models",
|
||||
@ -140,39 +103,39 @@ class addModelsForm(npyscreen.FormMultiPageAction):
|
||||
if x in recommended_models
|
||||
],
|
||||
max_height=len(starter_model_labels) + 1,
|
||||
relx = 4,
|
||||
scroll_exit=True,
|
||||
)
|
||||
for line in [
|
||||
'Import checkpoint/safetensor models from the directory below.',
|
||||
'(Use <tab> to autocomplete)'
|
||||
]:
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.TitleText,
|
||||
name=line,
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
)
|
||||
self.nextrely -= 1
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.TitleFixedText,
|
||||
name='== MODEL IMPORT DIRECTORY ==',
|
||||
value='Import all models found in this directory (<tab> autocompletes):',
|
||||
begin_entry_at=2,
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
)
|
||||
self.autoload_directory = self.add_widget_intelligent(
|
||||
npyscreen.TitleFilename,
|
||||
name='Directory:',
|
||||
select_dir=True,
|
||||
must_exist=True,
|
||||
use_two_lines=False,
|
||||
value=os.path.expanduser('~'+'/'),
|
||||
relx = 4,
|
||||
labelColor='DANGER',
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.autoload_onstartup = self.add_widget_intelligent(
|
||||
self.autoscan_on_startup = self.add_widget_intelligent(
|
||||
npyscreen.Checkbox,
|
||||
name='Scan this directory each time InvokeAI starts for new models to import.',
|
||||
value=False,
|
||||
relx = 4,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely += 1
|
||||
for line in [
|
||||
'In the space below, you may cut and paste URLs, paths to .ckpt/.safetensor files',
|
||||
'or HuggingFace diffusers repository names to import.',
|
||||
'(Use control-V or shift-control-V to paste):'
|
||||
'== INDIVIDUAL MODELS TO IMPORT ==',
|
||||
'Enter list of URLs, paths models or HuggingFace diffusers repository IDs.',
|
||||
'Use control-V or shift-control-V to paste:'
|
||||
]:
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.TitleText,
|
||||
@ -181,27 +144,33 @@ class addModelsForm(npyscreen.FormMultiPageAction):
|
||||
color="CONTROL",
|
||||
)
|
||||
self.nextrely -= 1
|
||||
self.model_names = self.add_widget_intelligent(
|
||||
npyscreen.MultiLineEdit,
|
||||
max_width=75,
|
||||
self.import_model_paths = self.add_widget_intelligent(
|
||||
TextBox,
|
||||
max_height=8,
|
||||
scroll_exit=True,
|
||||
relx=3
|
||||
editable=True,
|
||||
relx=4
|
||||
)
|
||||
self.autoload_onstartup = self.add_widget_intelligent(
|
||||
self.nextrely += 2
|
||||
self.convert_models = self.add_widget_intelligent(
|
||||
npyscreen.TitleSelectOne,
|
||||
name='Keep files in original format, or convert .ckpt/.safetensors into fast-loading diffusers models:',
|
||||
values=['Original format','Convert to diffusers format'],
|
||||
name='== CONVERT IMPORTED MODELS INTO DIFFUSERS==',
|
||||
values=['Keep original format','Convert to diffusers'],
|
||||
value=0,
|
||||
begin_entry_at=4,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.find_next_editable()
|
||||
# self.set_editing(self.models_selected)
|
||||
# self.display()
|
||||
# self.models_selected.editing=True
|
||||
# self.models_selected.edit()
|
||||
|
||||
def resize(self):
|
||||
super().resize()
|
||||
self.models_selected.values = self._get_starter_model_labels()
|
||||
# thought this would dynamically resize the widget, but no luck
|
||||
# self.previously_installed_models.columns = self._get_columns()
|
||||
# self.previously_installed_models.max_height = 2+len(self.installed_models) // self._get_columns()
|
||||
# self.previously_installed_models.make_contained_widgets()
|
||||
# self.previously_installed_models.display()
|
||||
|
||||
def _get_starter_model_labels(self):
|
||||
def _get_starter_model_labels(self)->List[str]:
|
||||
window_height, window_width = curses.initscr().getmaxyx()
|
||||
label_width = 25
|
||||
checkbox_width = 4
|
||||
@ -217,17 +186,89 @@ class addModelsForm(npyscreen.FormMultiPageAction):
|
||||
f"%-{label_width}s %s" % (names[x], descriptions[x]) for x in range(0,len(im))
|
||||
]
|
||||
|
||||
def _get_columns(self)->int:
|
||||
window_height, window_width = curses.initscr().getmaxyx()
|
||||
return 4 if window_width > 240 else 3 if window_width>160 else 2 if window_width>80 else 1
|
||||
|
||||
def on_ok(self):
|
||||
self.parentApp.setNextForm(None)
|
||||
self.parentApp.setNextForm('MONITOR_OUTPUT')
|
||||
self.editing = False
|
||||
self.parentApp.selected_models = [
|
||||
self.starter_model_list[x] for x in self.models_selected.value
|
||||
]
|
||||
npyscreen.notify(f"Installing selected {self.parentApp.selected_models}")
|
||||
self.parentApp.user_cancelled = False
|
||||
self.marshall_arguments()
|
||||
|
||||
def on_cancel(self):
|
||||
self.parentApp.setNextForm(None)
|
||||
self.parentApp.selected_models = None
|
||||
self.ParentApp.user_cancelled = True
|
||||
self.editing = False
|
||||
|
||||
def marshall_arguments(self):
|
||||
'''
|
||||
Assemble arguments and store as attributes of the application:
|
||||
.starter_models: dict of model names to install from INITIAL_CONFIGURE.yaml
|
||||
True => Install
|
||||
False => Remove
|
||||
.scan_directory: Path to a directory of models to scan and import
|
||||
.autoscan_on_startup: True if invokeai should scan and import at startup time
|
||||
.import_model_paths: list of URLs, repo_ids and file paths to import
|
||||
.convert_to_diffusers: if True, convert legacy checkpoints into diffusers
|
||||
'''
|
||||
# starter models to install/remove
|
||||
model_names = list(self.initial_models.keys())
|
||||
starter_models = dict(map(lambda x: (model_names[x], True), self.models_selected.value))
|
||||
if hasattr(self,'previously_installed_models'):
|
||||
unchecked = [
|
||||
self.previously_installed_models.values[x]
|
||||
for x in range(0,len(self.previously_installed_models.values))
|
||||
if x not in self.previously_installed_models.value
|
||||
]
|
||||
starter_models.update(
|
||||
map(lambda x: (x, False), unchecked)
|
||||
)
|
||||
self.parentApp.starter_models=starter_models
|
||||
|
||||
# load directory and whether to scan on startup
|
||||
self.parentApp.scan_directory = self.autoload_directory.value
|
||||
self.parentApp.autoscan_on_startup = self.autoscan_on_startup.value
|
||||
|
||||
# URLs and the like
|
||||
self.parentApp.import_model_paths = self.import_model_paths.value.split()
|
||||
self.parentApp.convert_to_diffusers = self.convert_models.value != 0
|
||||
|
||||
class Log(object):
|
||||
def __init__(self, writable):
|
||||
self.writable = writable
|
||||
|
||||
def __enter__(self):
|
||||
self._stdout = sys.stdout
|
||||
sys.stdout = self.writable
|
||||
return self
|
||||
def __exit__(self, *args):
|
||||
sys.stdout = self._stdout
|
||||
|
||||
class outputForm(npyscreen.ActionForm):
|
||||
def create(self):
|
||||
self.buffer = self.add_widget(
|
||||
npyscreen.BufferPager,
|
||||
editable=False,
|
||||
)
|
||||
|
||||
def write(self,string):
|
||||
if string != '\n':
|
||||
self.buffer.buffer([string])
|
||||
|
||||
def beforeEditing(self):
|
||||
myapplication = self.parentApp
|
||||
with Log(self):
|
||||
print(f'DEBUG: these models will be removed: {[x for x in myapplication.starter_models if not myapplication.starter_models[x]]}')
|
||||
print(f'DEBUG: these models will be installed: {[x for x in myapplication.starter_models if myapplication.starter_models[x]]}')
|
||||
print(f'DEBUG: this directory will be scanned: {myapplication.scan_directory}')
|
||||
print(f'DEBUG: scan at startup time? {myapplication.autoscan_on_startup}')
|
||||
print(f'DEBUG: these things will be downloaded: {myapplication.import_model_paths}')
|
||||
print(f'DEBUG: convert to diffusers? {myapplication.convert_to_diffusers}')
|
||||
|
||||
def on_ok(self):
|
||||
self.buffer.buffer(['goodbye!'])
|
||||
self.parentApp.setNextForm(None)
|
||||
self.editing = False
|
||||
|
||||
class AddModelApplication(npyscreen.NPSAppManaged):
|
||||
@ -242,316 +283,12 @@ class AddModelApplication(npyscreen.NPSAppManaged):
|
||||
addModelsForm,
|
||||
name="Add/Remove Models",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def recommended_datasets() -> dict:
|
||||
datasets = dict()
|
||||
for ds in Datasets.keys():
|
||||
if Datasets[ds].get("recommended", False):
|
||||
datasets[ds] = True
|
||||
return datasets
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def default_dataset() -> dict:
|
||||
datasets = dict()
|
||||
for ds in Datasets.keys():
|
||||
if Datasets[ds].get("default", False):
|
||||
datasets[ds] = True
|
||||
return datasets
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def all_datasets() -> dict:
|
||||
datasets = dict()
|
||||
for ds in Datasets.keys():
|
||||
datasets[ds] = True
|
||||
return datasets
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
# look for legacy model.ckpt in models directory and offer to
|
||||
# normalize its name
|
||||
def migrate_models_ckpt():
|
||||
model_path = os.path.join(Globals.root, Model_dir, Weights_dir)
|
||||
if not os.path.exists(os.path.join(model_path, "model.ckpt")):
|
||||
return
|
||||
new_name = Datasets["stable-diffusion-1.4"]["file"]
|
||||
print('You seem to have the Stable Diffusion v4.1 "model.ckpt" already installed.')
|
||||
rename = yes_or_no(f'Ok to rename it to "{new_name}" for future reference?')
|
||||
if rename:
|
||||
print(f"model.ckpt => {new_name}")
|
||||
os.replace(
|
||||
os.path.join(model_path, "model.ckpt"), os.path.join(model_path, new_name)
|
||||
self.output = self.addForm(
|
||||
'MONITOR_OUTPUT',
|
||||
outputForm,
|
||||
name='Model Install Output'
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def download_weight_datasets(
|
||||
models: dict, access_token: str, precision: str = "float32"
|
||||
):
|
||||
migrate_models_ckpt()
|
||||
successful = dict()
|
||||
for mod in models.keys():
|
||||
print(f"Downloading {mod}:")
|
||||
successful[mod] = _download_repo_or_file(
|
||||
Datasets[mod], access_token, precision=precision
|
||||
)
|
||||
return successful
|
||||
|
||||
|
||||
def _download_repo_or_file(
|
||||
mconfig: DictConfig, access_token: str, precision: str = "float32"
|
||||
) -> Path:
|
||||
path = None
|
||||
if mconfig["format"] == "ckpt":
|
||||
path = _download_ckpt_weights(mconfig, access_token)
|
||||
else:
|
||||
path = _download_diffusion_weights(mconfig, access_token, precision=precision)
|
||||
if "vae" in mconfig and "repo_id" in mconfig["vae"]:
|
||||
_download_diffusion_weights(
|
||||
mconfig["vae"], access_token, precision=precision
|
||||
)
|
||||
return path
|
||||
|
||||
|
||||
def _download_ckpt_weights(mconfig: DictConfig, access_token: str) -> Path:
|
||||
repo_id = mconfig["repo_id"]
|
||||
filename = mconfig["file"]
|
||||
cache_dir = os.path.join(Globals.root, Model_dir, Weights_dir)
|
||||
return hf_download_with_resume(
|
||||
repo_id=repo_id,
|
||||
model_dir=cache_dir,
|
||||
model_name=filename,
|
||||
access_token=access_token,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def download_from_hf(
|
||||
model_class: object, model_name: str, cache_subdir: Path = Path("hub"), **kwargs
|
||||
):
|
||||
print("", file=sys.stderr) # to prevent tqdm from overwriting
|
||||
path = global_cache_dir(cache_subdir)
|
||||
model = model_class.from_pretrained(
|
||||
model_name,
|
||||
cache_dir=path,
|
||||
resume_download=True,
|
||||
**kwargs,
|
||||
)
|
||||
model_name = "--".join(("models", *model_name.split("/")))
|
||||
return path / model_name if model else None
|
||||
|
||||
|
||||
def _download_diffusion_weights(
|
||||
mconfig: DictConfig, access_token: str, precision: str = "float32"
|
||||
):
|
||||
repo_id = mconfig["repo_id"]
|
||||
model_class = (
|
||||
StableDiffusionGeneratorPipeline
|
||||
if mconfig.get("format", None) == "diffusers"
|
||||
else AutoencoderKL
|
||||
)
|
||||
extra_arg_list = [{"revision": "fp16"}, {}] if precision == "float16" else [{}]
|
||||
path = None
|
||||
for extra_args in extra_arg_list:
|
||||
try:
|
||||
path = download_from_hf(
|
||||
model_class,
|
||||
repo_id,
|
||||
cache_subdir="diffusers",
|
||||
safety_checker=None,
|
||||
**extra_args,
|
||||
)
|
||||
except OSError as e:
|
||||
if str(e).startswith("fp16 is not a valid"):
|
||||
pass
|
||||
else:
|
||||
print(f"An unexpected error occurred while downloading the model: {e})")
|
||||
if path:
|
||||
break
|
||||
return path
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def hf_download_with_resume(
|
||||
repo_id: str, model_dir: str, model_name: str, access_token: str = None
|
||||
) -> Path:
|
||||
model_dest = Path(os.path.join(model_dir, model_name))
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
|
||||
url = hf_hub_url(repo_id, model_name)
|
||||
|
||||
header = {"Authorization": f"Bearer {access_token}"} if access_token else {}
|
||||
open_mode = "wb"
|
||||
exist_size = 0
|
||||
|
||||
if os.path.exists(model_dest):
|
||||
exist_size = os.path.getsize(model_dest)
|
||||
header["Range"] = f"bytes={exist_size}-"
|
||||
open_mode = "ab"
|
||||
|
||||
resp = requests.get(url, headers=header, stream=True)
|
||||
total = int(resp.headers.get("content-length", 0))
|
||||
|
||||
if (
|
||||
resp.status_code == 416
|
||||
): # "range not satisfiable", which means nothing to return
|
||||
print(f"* {model_name}: complete file found. Skipping.")
|
||||
return model_dest
|
||||
elif resp.status_code != 200:
|
||||
print(f"** An error occurred during downloading {model_name}: {resp.reason}")
|
||||
elif exist_size > 0:
|
||||
print(f"* {model_name}: partial file found. Resuming...")
|
||||
else:
|
||||
print(f"* {model_name}: Downloading...")
|
||||
|
||||
try:
|
||||
if total < 2000:
|
||||
print(f"*** ERROR DOWNLOADING {model_name}: {resp.text}")
|
||||
return None
|
||||
|
||||
with open(model_dest, open_mode) as file, tqdm(
|
||||
desc=model_name,
|
||||
initial=exist_size,
|
||||
total=total + exist_size,
|
||||
unit="iB",
|
||||
unit_scale=True,
|
||||
unit_divisor=1000,
|
||||
) as bar:
|
||||
for data in resp.iter_content(chunk_size=1024):
|
||||
size = file.write(data)
|
||||
bar.update(size)
|
||||
except Exception as e:
|
||||
print(f"An error occurred while downloading {model_name}: {str(e)}")
|
||||
return None
|
||||
return model_dest
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def update_config_file(successfully_downloaded: dict, opt: dict):
|
||||
config_file = (
|
||||
Path(opt.config_file) if opt.config_file is not None else Default_config_file
|
||||
)
|
||||
|
||||
# In some cases (incomplete setup, etc), the default configs directory might be missing.
|
||||
# Create it if it doesn't exist.
|
||||
# this check is ignored if opt.config_file is specified - user is assumed to know what they
|
||||
# are doing if they are passing a custom config file from elsewhere.
|
||||
if config_file is Default_config_file and not config_file.parent.exists():
|
||||
configs_src = Dataset_path.parent
|
||||
configs_dest = Default_config_file.parent
|
||||
shutil.copytree(configs_src, configs_dest, dirs_exist_ok=True)
|
||||
|
||||
yaml = new_config_file_contents(successfully_downloaded, config_file, opt)
|
||||
|
||||
try:
|
||||
backup = None
|
||||
if os.path.exists(config_file):
|
||||
print(
|
||||
f"** {config_file.name} exists. Renaming to {config_file.stem}.yaml.orig"
|
||||
)
|
||||
backup = config_file.with_suffix(".yaml.orig")
|
||||
## Ugh. Windows is unable to overwrite an existing backup file, raises a WinError 183
|
||||
if sys.platform == "win32" and backup.is_file():
|
||||
backup.unlink()
|
||||
config_file.rename(backup)
|
||||
|
||||
with TemporaryFile() as tmp:
|
||||
tmp.write(Config_preamble.encode())
|
||||
tmp.write(yaml.encode())
|
||||
|
||||
with open(str(config_file.expanduser().resolve()), "wb") as new_config:
|
||||
tmp.seek(0)
|
||||
new_config.write(tmp.read())
|
||||
|
||||
except Exception as e:
|
||||
print(f"**Error creating config file {config_file}: {str(e)} **")
|
||||
if backup is not None:
|
||||
print("restoring previous config file")
|
||||
## workaround, for WinError 183, see above
|
||||
if sys.platform == "win32" and config_file.is_file():
|
||||
config_file.unlink()
|
||||
backup.rename(config_file)
|
||||
return
|
||||
|
||||
print(f"Successfully created new configuration file {config_file}")
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def new_config_file_contents(
|
||||
successfully_downloaded: dict, config_file: Path, opt: dict
|
||||
) -> str:
|
||||
if config_file.exists():
|
||||
conf = OmegaConf.load(str(config_file.expanduser().resolve()))
|
||||
else:
|
||||
conf = OmegaConf.create()
|
||||
|
||||
default_selected = None
|
||||
for model in successfully_downloaded:
|
||||
# a bit hacky - what we are doing here is seeing whether a checkpoint
|
||||
# version of the model was previously defined, and whether the current
|
||||
# model is a diffusers (indicated with a path)
|
||||
if conf.get(model) and Path(successfully_downloaded[model]).is_dir():
|
||||
offer_to_delete_weights(model, conf[model], opt.yes_to_all)
|
||||
|
||||
stanza = {}
|
||||
mod = Datasets[model]
|
||||
stanza["description"] = mod["description"]
|
||||
stanza["repo_id"] = mod["repo_id"]
|
||||
stanza["format"] = mod["format"]
|
||||
# diffusers don't need width and height (probably .ckpt doesn't either)
|
||||
# so we no longer require these in INITIAL_MODELS.yaml
|
||||
if "width" in mod:
|
||||
stanza["width"] = mod["width"]
|
||||
if "height" in mod:
|
||||
stanza["height"] = mod["height"]
|
||||
if "file" in mod:
|
||||
stanza["weights"] = os.path.relpath(
|
||||
successfully_downloaded[model], start=Globals.root
|
||||
)
|
||||
stanza["config"] = os.path.normpath(os.path.join(SD_Configs, mod["config"]))
|
||||
if "vae" in mod:
|
||||
if "file" in mod["vae"]:
|
||||
stanza["vae"] = os.path.normpath(
|
||||
os.path.join(Model_dir, Weights_dir, mod["vae"]["file"])
|
||||
)
|
||||
else:
|
||||
stanza["vae"] = mod["vae"]
|
||||
if mod.get("default", False):
|
||||
stanza["default"] = True
|
||||
default_selected = True
|
||||
|
||||
conf[model] = stanza
|
||||
|
||||
# if no default model was chosen, then we select the first
|
||||
# one in the list
|
||||
if not default_selected:
|
||||
conf[list(successfully_downloaded.keys())[0]]["default"] = True
|
||||
|
||||
return OmegaConf.to_yaml(conf)
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def offer_to_delete_weights(model_name: str, conf_stanza: dict, yes_to_all: bool):
|
||||
if not (weights := conf_stanza.get("weights")):
|
||||
return
|
||||
if re.match("/VAE/", conf_stanza.get("config")):
|
||||
return
|
||||
if yes_to_all or yes_or_no(
|
||||
f"\n** The checkpoint version of {model_name} is superseded by the diffusers version. Delete the original file {weights}?",
|
||||
default_yes=False,
|
||||
):
|
||||
weights = Path(weights)
|
||||
if not weights.is_absolute():
|
||||
weights = Path(Globals.root) / weights
|
||||
try:
|
||||
weights.unlink()
|
||||
except OSError as e:
|
||||
print(str(e))
|
||||
|
||||
|
||||
# --------------------------------------------------------
|
||||
def select_and_download_models(opt: Namespace):
|
||||
if opt.default_only:
|
||||
@ -559,7 +296,14 @@ def select_and_download_models(opt: Namespace):
|
||||
else:
|
||||
myapplication = AddModelApplication()
|
||||
myapplication.run()
|
||||
models_to_download = dict(map(lambda x: (x, True), myapplication.selected_models)) if myapplication.selected_models else None
|
||||
if not myapplication.user_cancelled:
|
||||
print(f'DEBUG: these models will be removed: {[x for x in myapplication.starter_models if not myapplication.starter_models[x]]}')
|
||||
print(f'DEBUG: these models will be installed: {[x for x in myapplication.starter_models if myapplication.starter_models[x]]}')
|
||||
print(f'DEBUG: this directory will be scanned: {myapplication.scan_directory}')
|
||||
print(f'DEBUG: scan at startup time? {myapplication.autoscan_on_startup}')
|
||||
print(f'DEBUG: these things will be downloaded: {myapplication.import_model_paths}')
|
||||
print(f'DEBUG: convert to diffusers? {myapplication.convert_to_diffusers}')
|
||||
sys.exit(0)
|
||||
|
||||
if not models_to_download:
|
||||
print(
|
||||
@ -649,6 +393,7 @@ def main():
|
||||
)
|
||||
else:
|
||||
print(f"** A layout error has occurred: {str(e)}")
|
||||
traceback.print_exc()
|
||||
sys.exit(-1)
|
||||
|
||||
# -------------------------------------
|
||||
|
378
ldm/invoke/config/model_install_util.py
Normal file
378
ldm/invoke/config/model_install_util.py
Normal file
@ -0,0 +1,378 @@
|
||||
'''
|
||||
Utility (backend) functions used by model_install.py
|
||||
'''
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
import traceback
|
||||
import warnings
|
||||
from argparse import Namespace
|
||||
from math import ceil
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryFile
|
||||
|
||||
import npyscreen
|
||||
import requests
|
||||
from diffusers import AutoencoderKL
|
||||
from huggingface_hub import hf_hub_url
|
||||
from omegaconf import OmegaConf
|
||||
from omegaconf.dictconfig import DictConfig
|
||||
from tqdm import tqdm
|
||||
|
||||
import invokeai.configs as configs
|
||||
from ldm.invoke.devices import choose_precision, choose_torch_device
|
||||
from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
||||
from ldm.invoke.globals import Globals, global_cache_dir, global_config_dir
|
||||
from ldm.invoke.config.widgets import MultiSelectColumns
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
import torch
|
||||
|
||||
# --------------------------globals-----------------------
|
||||
Model_dir = "models"
|
||||
Weights_dir = "ldm/stable-diffusion-v1/"
|
||||
|
||||
# the initial "configs" dir is now bundled in the `invokeai.configs` package
|
||||
Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml"
|
||||
|
||||
Default_config_file = Path(global_config_dir()) / "models.yaml"
|
||||
SD_Configs = Path(global_config_dir()) / "stable-diffusion"
|
||||
|
||||
Datasets = OmegaConf.load(Dataset_path)
|
||||
|
||||
Config_preamble = """# This file describes the alternative machine learning models
|
||||
# available to InvokeAI script.
|
||||
#
|
||||
# To add a new model, follow the examples below. Each
|
||||
# model requires a model config file, a weights file,
|
||||
# and the width and height of the images it
|
||||
# was trained on.
|
||||
"""
|
||||
|
||||
# -------------------------------------
|
||||
def yes_or_no(prompt: str, default_yes=True):
|
||||
default = "y" if default_yes else "n"
|
||||
response = input(f"{prompt} [{default}] ") or default
|
||||
if default_yes:
|
||||
return response[0] not in ("n", "N")
|
||||
else:
|
||||
return response[0] in ("y", "Y")
|
||||
|
||||
# -------------------------------------
|
||||
def get_root(root: str = None) -> str:
|
||||
if root:
|
||||
return root
|
||||
elif os.environ.get("INVOKEAI_ROOT"):
|
||||
return os.environ.get("INVOKEAI_ROOT")
|
||||
else:
|
||||
return Globals.root
|
||||
|
||||
# ---------------------------------------------
|
||||
def recommended_datasets() -> dict:
|
||||
datasets = dict()
|
||||
for ds in Datasets.keys():
|
||||
if Datasets[ds].get("recommended", False):
|
||||
datasets[ds] = True
|
||||
return datasets
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def default_dataset() -> dict:
|
||||
datasets = dict()
|
||||
for ds in Datasets.keys():
|
||||
if Datasets[ds].get("default", False):
|
||||
datasets[ds] = True
|
||||
return datasets
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def all_datasets() -> dict:
|
||||
datasets = dict()
|
||||
for ds in Datasets.keys():
|
||||
datasets[ds] = True
|
||||
return datasets
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
# look for legacy model.ckpt in models directory and offer to
|
||||
# normalize its name
|
||||
def migrate_models_ckpt():
|
||||
model_path = os.path.join(Globals.root, Model_dir, Weights_dir)
|
||||
if not os.path.exists(os.path.join(model_path, "model.ckpt")):
|
||||
return
|
||||
new_name = Datasets["stable-diffusion-1.4"]["file"]
|
||||
print('You seem to have the Stable Diffusion v4.1 "model.ckpt" already installed.')
|
||||
rename = yes_or_no(f'Ok to rename it to "{new_name}" for future reference?')
|
||||
if rename:
|
||||
print(f"model.ckpt => {new_name}")
|
||||
os.replace(
|
||||
os.path.join(model_path, "model.ckpt"), os.path.join(model_path, new_name)
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def download_weight_datasets(
|
||||
models: dict, access_token: str, precision: str = "float32"
|
||||
):
|
||||
migrate_models_ckpt()
|
||||
successful = dict()
|
||||
for mod in models.keys():
|
||||
print(f"Downloading {mod}:")
|
||||
successful[mod] = _download_repo_or_file(
|
||||
Datasets[mod], access_token, precision=precision
|
||||
)
|
||||
return successful
|
||||
|
||||
|
||||
def _download_repo_or_file(
|
||||
mconfig: DictConfig, access_token: str, precision: str = "float32"
|
||||
) -> Path:
|
||||
path = None
|
||||
if mconfig["format"] == "ckpt":
|
||||
path = _download_ckpt_weights(mconfig, access_token)
|
||||
else:
|
||||
path = _download_diffusion_weights(mconfig, access_token, precision=precision)
|
||||
if "vae" in mconfig and "repo_id" in mconfig["vae"]:
|
||||
_download_diffusion_weights(
|
||||
mconfig["vae"], access_token, precision=precision
|
||||
)
|
||||
return path
|
||||
|
||||
|
||||
def _download_ckpt_weights(mconfig: DictConfig, access_token: str) -> Path:
|
||||
repo_id = mconfig["repo_id"]
|
||||
filename = mconfig["file"]
|
||||
cache_dir = os.path.join(Globals.root, Model_dir, Weights_dir)
|
||||
return hf_download_with_resume(
|
||||
repo_id=repo_id,
|
||||
model_dir=cache_dir,
|
||||
model_name=filename,
|
||||
access_token=access_token,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def download_from_hf(
|
||||
model_class: object, model_name: str, cache_subdir: Path = Path("hub"), **kwargs
|
||||
):
|
||||
print("", file=sys.stderr) # to prevent tqdm from overwriting
|
||||
path = global_cache_dir(cache_subdir)
|
||||
model = model_class.from_pretrained(
|
||||
model_name,
|
||||
cache_dir=path,
|
||||
resume_download=True,
|
||||
**kwargs,
|
||||
)
|
||||
model_name = "--".join(("models", *model_name.split("/")))
|
||||
return path / model_name if model else None
|
||||
|
||||
|
||||
def _download_diffusion_weights(
|
||||
mconfig: DictConfig, access_token: str, precision: str = "float32"
|
||||
):
|
||||
repo_id = mconfig["repo_id"]
|
||||
model_class = (
|
||||
StableDiffusionGeneratorPipeline
|
||||
if mconfig.get("format", None) == "diffusers"
|
||||
else AutoencoderKL
|
||||
)
|
||||
extra_arg_list = [{"revision": "fp16"}, {}] if precision == "float16" else [{}]
|
||||
path = None
|
||||
for extra_args in extra_arg_list:
|
||||
try:
|
||||
path = download_from_hf(
|
||||
model_class,
|
||||
repo_id,
|
||||
cache_subdir="diffusers",
|
||||
safety_checker=None,
|
||||
**extra_args,
|
||||
)
|
||||
except OSError as e:
|
||||
if str(e).startswith("fp16 is not a valid"):
|
||||
pass
|
||||
else:
|
||||
print(f"An unexpected error occurred while downloading the model: {e})")
|
||||
if path:
|
||||
break
|
||||
return path
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def hf_download_with_resume(
|
||||
repo_id: str, model_dir: str, model_name: str, access_token: str = None
|
||||
) -> Path:
|
||||
model_dest = Path(os.path.join(model_dir, model_name))
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
|
||||
url = hf_hub_url(repo_id, model_name)
|
||||
|
||||
header = {"Authorization": f"Bearer {access_token}"} if access_token else {}
|
||||
open_mode = "wb"
|
||||
exist_size = 0
|
||||
|
||||
if os.path.exists(model_dest):
|
||||
exist_size = os.path.getsize(model_dest)
|
||||
header["Range"] = f"bytes={exist_size}-"
|
||||
open_mode = "ab"
|
||||
|
||||
resp = requests.get(url, headers=header, stream=True)
|
||||
total = int(resp.headers.get("content-length", 0))
|
||||
|
||||
if (
|
||||
resp.status_code == 416
|
||||
): # "range not satisfiable", which means nothing to return
|
||||
print(f"* {model_name}: complete file found. Skipping.")
|
||||
return model_dest
|
||||
elif resp.status_code != 200:
|
||||
print(f"** An error occurred during downloading {model_name}: {resp.reason}")
|
||||
elif exist_size > 0:
|
||||
print(f"* {model_name}: partial file found. Resuming...")
|
||||
else:
|
||||
print(f"* {model_name}: Downloading...")
|
||||
|
||||
try:
|
||||
if total < 2000:
|
||||
print(f"*** ERROR DOWNLOADING {model_name}: {resp.text}")
|
||||
return None
|
||||
|
||||
with open(model_dest, open_mode) as file, tqdm(
|
||||
desc=model_name,
|
||||
initial=exist_size,
|
||||
total=total + exist_size,
|
||||
unit="iB",
|
||||
unit_scale=True,
|
||||
unit_divisor=1000,
|
||||
) as bar:
|
||||
for data in resp.iter_content(chunk_size=1024):
|
||||
size = file.write(data)
|
||||
bar.update(size)
|
||||
except Exception as e:
|
||||
print(f"An error occurred while downloading {model_name}: {str(e)}")
|
||||
return None
|
||||
return model_dest
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def update_config_file(successfully_downloaded: dict, opt: dict):
|
||||
config_file = (
|
||||
Path(opt.config_file) if opt.config_file is not None else Default_config_file
|
||||
)
|
||||
|
||||
# In some cases (incomplete setup, etc), the default configs directory might be missing.
|
||||
# Create it if it doesn't exist.
|
||||
# this check is ignored if opt.config_file is specified - user is assumed to know what they
|
||||
# are doing if they are passing a custom config file from elsewhere.
|
||||
if config_file is Default_config_file and not config_file.parent.exists():
|
||||
configs_src = Dataset_path.parent
|
||||
configs_dest = Default_config_file.parent
|
||||
shutil.copytree(configs_src, configs_dest, dirs_exist_ok=True)
|
||||
|
||||
yaml = new_config_file_contents(successfully_downloaded, config_file, opt)
|
||||
|
||||
try:
|
||||
backup = None
|
||||
if os.path.exists(config_file):
|
||||
print(
|
||||
f"** {config_file.name} exists. Renaming to {config_file.stem}.yaml.orig"
|
||||
)
|
||||
backup = config_file.with_suffix(".yaml.orig")
|
||||
## Ugh. Windows is unable to overwrite an existing backup file, raises a WinError 183
|
||||
if sys.platform == "win32" and backup.is_file():
|
||||
backup.unlink()
|
||||
config_file.rename(backup)
|
||||
|
||||
with TemporaryFile() as tmp:
|
||||
tmp.write(Config_preamble.encode())
|
||||
tmp.write(yaml.encode())
|
||||
|
||||
with open(str(config_file.expanduser().resolve()), "wb") as new_config:
|
||||
tmp.seek(0)
|
||||
new_config.write(tmp.read())
|
||||
|
||||
except Exception as e:
|
||||
print(f"**Error creating config file {config_file}: {str(e)} **")
|
||||
if backup is not None:
|
||||
print("restoring previous config file")
|
||||
## workaround, for WinError 183, see above
|
||||
if sys.platform == "win32" and config_file.is_file():
|
||||
config_file.unlink()
|
||||
backup.rename(config_file)
|
||||
return
|
||||
|
||||
print(f"Successfully created new configuration file {config_file}")
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def new_config_file_contents(
|
||||
successfully_downloaded: dict, config_file: Path, opt: dict
|
||||
) -> str:
|
||||
if config_file.exists():
|
||||
conf = OmegaConf.load(str(config_file.expanduser().resolve()))
|
||||
else:
|
||||
conf = OmegaConf.create()
|
||||
|
||||
default_selected = None
|
||||
for model in successfully_downloaded:
|
||||
# a bit hacky - what we are doing here is seeing whether a checkpoint
|
||||
# version of the model was previously defined, and whether the current
|
||||
# model is a diffusers (indicated with a path)
|
||||
if conf.get(model) and Path(successfully_downloaded[model]).is_dir():
|
||||
offer_to_delete_weights(model, conf[model], opt.yes_to_all)
|
||||
|
||||
stanza = {}
|
||||
mod = Datasets[model]
|
||||
stanza["description"] = mod["description"]
|
||||
stanza["repo_id"] = mod["repo_id"]
|
||||
stanza["format"] = mod["format"]
|
||||
# diffusers don't need width and height (probably .ckpt doesn't either)
|
||||
# so we no longer require these in INITIAL_MODELS.yaml
|
||||
if "width" in mod:
|
||||
stanza["width"] = mod["width"]
|
||||
if "height" in mod:
|
||||
stanza["height"] = mod["height"]
|
||||
if "file" in mod:
|
||||
stanza["weights"] = os.path.relpath(
|
||||
successfully_downloaded[model], start=Globals.root
|
||||
)
|
||||
stanza["config"] = os.path.normpath(os.path.join(SD_Configs, mod["config"]))
|
||||
if "vae" in mod:
|
||||
if "file" in mod["vae"]:
|
||||
stanza["vae"] = os.path.normpath(
|
||||
os.path.join(Model_dir, Weights_dir, mod["vae"]["file"])
|
||||
)
|
||||
else:
|
||||
stanza["vae"] = mod["vae"]
|
||||
if mod.get("default", False):
|
||||
stanza["default"] = True
|
||||
default_selected = True
|
||||
|
||||
conf[model] = stanza
|
||||
|
||||
# if no default model was chosen, then we select the first
|
||||
# one in the list
|
||||
if not default_selected:
|
||||
conf[list(successfully_downloaded.keys())[0]]["default"] = True
|
||||
|
||||
return OmegaConf.to_yaml(conf)
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def offer_to_delete_weights(model_name: str, conf_stanza: dict, yes_to_all: bool):
|
||||
if not (weights := conf_stanza.get("weights")):
|
||||
return
|
||||
if re.match("/VAE/", conf_stanza.get("config")):
|
||||
return
|
||||
if yes_to_all or yes_or_no(
|
||||
f"\n** The checkpoint version of {model_name} is superseded by the diffusers version. Delete the original file {weights}?",
|
||||
default_yes=False,
|
||||
):
|
||||
weights = Path(weights)
|
||||
if not weights.is_absolute():
|
||||
weights = Path(Globals.root) / weights
|
||||
try:
|
||||
weights.unlink()
|
||||
except OSError as e:
|
||||
print(str(e))
|
||||
|
@ -68,3 +68,30 @@ class MultiSelectColumns(npyscreen.MultiSelect):
|
||||
|
||||
def h_cursor_line_right(self,ch):
|
||||
super().h_cursor_line_down(ch)
|
||||
|
||||
class TextBox(npyscreen.MultiLineEdit):
|
||||
def update(self, clear=True):
|
||||
if clear: self.clear()
|
||||
|
||||
HEIGHT = self.height
|
||||
WIDTH = self.width
|
||||
# draw box.
|
||||
self.parent.curses_pad.hline(self.rely, self.relx, curses.ACS_HLINE, WIDTH)
|
||||
self.parent.curses_pad.hline(self.rely + HEIGHT, self.relx, curses.ACS_HLINE, WIDTH)
|
||||
self.parent.curses_pad.vline(self.rely, self.relx, curses.ACS_VLINE, self.height)
|
||||
self.parent.curses_pad.vline(self.rely, self.relx+WIDTH, curses.ACS_VLINE, HEIGHT)
|
||||
|
||||
# draw corners
|
||||
self.parent.curses_pad.addch(self.rely, self.relx, curses.ACS_ULCORNER, )
|
||||
self.parent.curses_pad.addch(self.rely, self.relx+WIDTH, curses.ACS_URCORNER, )
|
||||
self.parent.curses_pad.addch(self.rely+HEIGHT, self.relx, curses.ACS_LLCORNER, )
|
||||
self.parent.curses_pad.addch(self.rely+HEIGHT, self.relx+WIDTH, curses.ACS_LRCORNER, )
|
||||
|
||||
# fool our superclass into thinking drawing area is smaller - this is really hacky but it seems to work
|
||||
(relx,rely,height,width) = (self.relx, self.rely, self.height, self.width)
|
||||
self.relx += 1
|
||||
self.rely += 1
|
||||
self.height -= 1
|
||||
self.width -= 1
|
||||
super().update(clear=False)
|
||||
(self.relx,self.rely,self.height,self.width) = (relx, rely, height, width)
|
||||
|
Loading…
Reference in New Issue
Block a user