rewrite of widget display - marshalling needs rewrite

Lincoln Stein
2023-06-15 23:32:33 -04:00
parent 5c740452f6
commit ada7399753
7 changed files with 473 additions and 464 deletions

View File

@ -16,6 +16,7 @@ import shutil
import textwrap
import traceback
import warnings
import yaml
from argparse import Namespace
from pathlib import Path
from shutil import get_terminal_size
@ -25,6 +26,7 @@ from urllib import request
import npyscreen
import transformers
from diffusers import AutoencoderKL
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from huggingface_hub import HfFolder
from huggingface_hub import login as hf_hub_login
from omegaconf import OmegaConf
@ -34,6 +36,8 @@ from transformers import (
CLIPSegForImageSegmentation,
CLIPTextModel,
CLIPTokenizer,
AutoFeatureExtractor,
BertTokenizerFast,
)
import invokeai.configs as configs
@ -58,6 +62,9 @@ from invokeai.backend.install.model_install_backend import (
recommended_datasets,
UserSelections,
)
from invokeai.backend.model_management.model_probe import (
ModelProbe, ModelType, BaseModelType, SchedulerPredictionType
)
warnings.filterwarnings("ignore")
transformers.logging.set_verbosity_error()
@ -81,7 +88,7 @@ INIT_FILE_PREAMBLE = """# InvokeAI initialization file
# or renaming it and then running invokeai-configure again.
"""
logger=None
logger=InvokeAILogger.getLogger()
# --------------------------------------------
def postscript(errors: None):
@ -162,75 +169,91 @@ class ProgressBar:
# ---------------------------------------------
def download_with_progress_bar(model_url: str, model_dest: str, label: str = "the"):
try:
print(f"Installing {label} model file {model_url}...", end="", file=sys.stderr)
logger.info(f"Installing {label} model file {model_url}...")
if not os.path.exists(model_dest):
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
request.urlretrieve(
model_url, model_dest, ProgressBar(os.path.basename(model_dest))
)
print("...downloaded successfully", file=sys.stderr)
logger.info("...downloaded successfully")
else:
print("...exists", file=sys.stderr)
logger.info("...exists")
except Exception:
print("...download failed", file=sys.stderr)
print(f"Error downloading {label} model", file=sys.stderr)
logger.info("...download failed")
logger.info(f"Error downloading {label} model")
print(traceback.format_exc(), file=sys.stderr)
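Note that request.urlretrieve treats its third argument as a reporthook, invoked as hook(block_num, block_size, total_size) after each chunk. A minimal tqdm-based sketch of a compatible ProgressBar follows; the repo's own ProgressBar (whose body is elided by the hunk above) may differ in detail:

from tqdm import tqdm

class ProgressBar:
    # urlretrieve reporthook: called as hook(block_num, block_size, total_size)
    def __init__(self, model_name: str = "file"):
        self.pbar = None
        self.name = model_name

    def __call__(self, block_num: int, block_size: int, total_size: int):
        if self.pbar is None:
            # total_size may be -1 when the server sends no Content-Length
            self.pbar = tqdm(desc=self.name, unit="iB", unit_scale=True,
                             total=total_size if total_size > 0 else None)
        self.pbar.update(block_size)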
# ---------------------------------------------
# this will preload the Bert tokenizer files
def download_bert():
print("Installing bert tokenizer...", file=sys.stderr)
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning)
from transformers import BertTokenizerFast
def download_conversion_models():
target_dir = config.root_path / 'models/core/convert'
kwargs = dict() # for future use
try:
logger.info('Downloading core tokenizers and text encoders')
download_from_hf(BertTokenizerFast, "bert-base-uncased")
# bert
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning)
bert = BertTokenizerFast.from_pretrained("bert-base-uncased", **kwargs)
bert.save_pretrained(target_dir / 'bert-base-uncased', safe_serialization=True)
# sd-1
repo_id = 'openai/clip-vit-large-patch14'
download_from_hf(CLIPTokenizer, repo_id, target_dir / 'clip-vit-large-patch14')
download_from_hf(CLIPTextModel, repo_id, target_dir / 'clip-vit-large-patch14')
# sd-2
repo_id = "stabilityai/stable-diffusion-2"
pipeline = CLIPTokenizer.from_pretrained(repo_id, subfolder="tokenizer", **kwargs)
pipeline.save_pretrained(target_dir / 'stable-diffusion-2-clip' / 'tokenizer', safe_serialization=True)
# ---------------------------------------------
def download_sd1_clip():
print("Installing SD1 clip model...", file=sys.stderr)
version = "openai/clip-vit-large-patch14"
download_from_hf(CLIPTokenizer, version)
download_from_hf(CLIPTextModel, version)
pipeline = CLIPTextModel.from_pretrained(repo_id, subfolder="text_encoder", **kwargs)
pipeline.save_pretrained(target_dir / 'stable-diffusion-2-clip' / 'text_encoder', safe_serialization=True)
# VAE
logger.info('Downloading stable diffusion VAE')
vae = AutoencoderKL.from_pretrained('stabilityai/sd-vae-ft-mse', **kwargs)
vae.save_pretrained(target_dir / 'sd-vae-ft-mse', safe_serialization=True)
# ---------------------------------------------
def download_sd2_clip():
version = "stabilityai/stable-diffusion-2"
print("Installing SD2 clip model...", file=sys.stderr)
download_from_hf(CLIPTokenizer, version, subfolder="tokenizer")
download_from_hf(CLIPTextModel, version, subfolder="text_encoder")
# safety checking
logger.info('Downloading safety checker')
repo_id = "CompVis/stable-diffusion-safety-checker"
pipeline = AutoFeatureExtractor.from_pretrained(repo_id,**kwargs)
pipeline.save_pretrained(target_dir / 'stable-diffusion-safety-checker', safe_serialization=True)
pipeline = StableDiffusionSafetyChecker.from_pretrained(repo_id,**kwargs)
pipeline.save_pretrained(target_dir / 'stable-diffusion-safety-checker', safe_serialization=True)
except KeyboardInterrupt:
raise
except Exception as e:
logger.error(str(e))
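The pattern above (from_pretrained off the Hub, then save_pretrained with safe_serialization=True) localizes each model into a self-contained folder of safetensors files that later loads can reference directly. A sketch of the same round trip, with an illustrative destination path (the repo id is real):

from pathlib import Path
from transformers import CLIPTextModel, CLIPTokenizer

target = Path("models/core/convert/clip-vit-large-patch14")  # illustrative destination
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
tokenizer.save_pretrained(target)
encoder.save_pretrained(target, safe_serialization=True)

# later loads can point at the local folder and skip the Hub entirely
encoder = CLIPTextModel.from_pretrained(target)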
# ---------------------------------------------
def download_realesrgan():
print("Installing models from RealESRGAN...", file=sys.stderr)
logger.info("Installing models from RealESRGAN...")
model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"
wdn_model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth"
model_dest = config.root_path / "models/realesrgan/realesr-general-x4v3.pth"
wdn_model_dest = config.root_path / "models/realesrgan/realesr-general-wdn-x4v3.pth"
model_dest = config.root_path / "models/core/upscaling/realesrgan/realesr-general-x4v3.pth"
wdn_model_dest = config.root_path / "models/core/upscaling/realesrgan/realesr-general-wdn-x4v3.pth"
download_with_progress_bar(model_url, str(model_dest), "RealESRGAN")
download_with_progress_bar(wdn_model_url, str(wdn_model_dest), "RealESRGANwdn")
def download_gfpgan():
print("Installing GFPGAN models...", file=sys.stderr)
logger.info("Installing GFPGAN models...")
for model in (
[
"https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth",
"./models/gfpgan/GFPGANv1.4.pth",
"./models/core/face_restoration/gfpgan/GFPGANv1.4.pth",
],
[
"https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth",
"./models/gfpgan/weights/detection_Resnet50_Final.pth",
"./models/core/face_restoration/gfpgan/weights/detection_Resnet50_Final.pth",
],
[
"https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth",
"./models/gfpgan/weights/parsing_parsenet.pth",
"./models/core/face_restoration/gfpgan/weights/parsing_parsenet.pth",
],
):
model_url, model_dest = model[0], config.root_path / model[1]
@ -239,70 +262,32 @@ def download_gfpgan():
# ---------------------------------------------
def download_codeformer():
print("Installing CodeFormer model file...", file=sys.stderr)
logger.info("Installing CodeFormer model file...")
model_url = (
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
)
model_dest = config.root_path / "models/codeformer/codeformer.pth"
model_dest = config.root_path / "models/core/face_restoration/codeformer/codeformer.pth"
download_with_progress_bar(model_url, str(model_dest), "CodeFormer")
# ---------------------------------------------
def download_clipseg():
print("Installing clipseg model for text-based masking...", file=sys.stderr)
logger.info("Installing clipseg model for text-based masking...")
CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
try:
download_from_hf(AutoProcessor, CLIPSEG_MODEL)
download_from_hf(CLIPSegForImageSegmentation, CLIPSEG_MODEL)
download_from_hf(AutoProcessor, CLIPSEG_MODEL, config.root_path / 'models/core/misc/clipseg')
download_from_hf(CLIPSegForImageSegmentation, CLIPSEG_MODEL,'models/core/misc/clipseg')
except Exception:
print("Error installing clipseg model:")
print(traceback.format_exc())
logger.info("Error installing clipseg model:")
logger.info(traceback.format_exc())
# -------------------------------------
def download_safety_checker():
print("Installing model for NSFW content detection...", file=sys.stderr)
try:
from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
)
from transformers import AutoFeatureExtractor
except ModuleNotFoundError:
print("Error installing NSFW checker model:")
print(traceback.format_exc())
return
safety_model_id = "CompVis/stable-diffusion-safety-checker"
print("AutoFeatureExtractor...", file=sys.stderr)
download_from_hf(AutoFeatureExtractor, safety_model_id)
print("StableDiffusionSafetyChecker...", file=sys.stderr)
download_from_hf(StableDiffusionSafetyChecker, safety_model_id)
# -------------------------------------
def download_vaes():
print("Installing stabilityai VAE...", file=sys.stderr)
try:
# first the diffusers version
repo_id = "stabilityai/sd-vae-ft-mse"
args = dict(
cache_dir=config.cache_dir,
)
if not AutoencoderKL.from_pretrained(repo_id, **args):
raise Exception(f"download of {repo_id} failed")
repo_id = "stabilityai/sd-vae-ft-mse-original"
model_name = "vae-ft-mse-840000-ema-pruned.ckpt"
# next the legacy checkpoint version
if not hf_download_with_resume(
repo_id=repo_id,
model_name=model_name,
model_dir=str(config.root_path / Model_dir / Weights_dir),
):
raise Exception(f"download of {model_name} failed")
except Exception as e:
print(f"Error downloading StabilityAI standard VAE: {str(e)}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
def download_support_models():
download_realesrgan()
download_gfpgan()
download_codeformer()
download_clipseg()
download_conversion_models()
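Taken together, the destination paths used by download_support_models and its helpers define the consolidated support-model tree this commit introduces. Assembled here for reference; the grouping is read off the paths in this diff, not a separate spec:

from pathlib import Path

CORE_LAYOUT = {
    "convert": ["bert-base-uncased", "clip-vit-large-patch14",
                "stable-diffusion-2-clip", "sd-vae-ft-mse",
                "stable-diffusion-safety-checker"],
    "upscaling": ["realesrgan"],
    "face_restoration": ["gfpgan", "codeformer"],
    "misc": ["clipseg"],
}

def core_paths(root: Path) -> list[Path]:
    # expand the layout into concrete directories under <root>/models/core
    return [root / "models" / "core" / group / name
            for group, names in CORE_LAYOUT.items() for name in names]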
# -------------------------------------
def get_root(root: str = None) -> str:
@ -657,17 +642,13 @@ def default_user_selections(program_opts: Namespace) -> UserSelections:
# -------------------------------------
def initialize_rootdir(root: Path, yes_to_all: bool = False):
print("** INITIALIZING INVOKEAI RUNTIME DIRECTORY **")
logger.info("** INITIALIZING INVOKEAI RUNTIME DIRECTORY **")
for name in (
"models",
"configs",
"embeddings",
"databases",
"loras",
"controlnets",
"text-inversion-output",
"text-inversion-training-data",
"configs"
):
os.makedirs(os.path.join(root, name), exist_ok=True)
@ -676,6 +657,22 @@ def initialize_rootdir(root: Path, yes_to_all: bool = False):
if not os.path.samefile(configs_src, configs_dest):
shutil.copytree(configs_src, configs_dest, dirs_exist_ok=True)
dest = root / 'models'
for model_base in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:
    for model_type in [ModelType.Pipeline, ModelType.Vae, ModelType.Lora,
                       ModelType.ControlNet, ModelType.TextualInversion]:
        path = dest / model_base.value / model_type.value
        path.mkdir(parents=True, exist_ok=True)
path = dest / 'core'
path.mkdir(parents=True, exist_ok=True)

with open(root / 'configs' / 'models.yaml', 'w') as yaml_file:
    yaml_file.write(yaml.dump({'__metadata__': {'version': '3.0.0'}}))
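For reference, yaml.dump on that nested dict produces the minimal v3 stub the new models.yaml starts out with:

import yaml

print(yaml.dump({'__metadata__': {'version': '3.0.0'}}))
# __metadata__:
#   version: 3.0.0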
# -------------------------------------
def run_console_ui(
@ -837,7 +834,7 @@ def main():
old_init_file = config.root_path / 'invokeai.init'
new_init_file = config.root_path / 'invokeai.yaml'
if old_init_file.exists() and not new_init_file.exists():
print('** Migrating invokeai.init to invokeai.yaml')
logger.info('** Migrating invokeai.init to invokeai.yaml')
migrate_init_file(old_init_file)
# Load new init file into config
config.parse_args(argv=[],conf=OmegaConf.load(new_init_file))
@ -855,29 +852,21 @@ def main():
if init_options:
write_opts(init_options, new_init_file)
else:
print(
logger.info(
'\n** CANCELLED AT USER\'S REQUEST. USE THE "invoke.sh" LAUNCHER TO RUN LATER **\n'
)
sys.exit(0)
if opt.skip_support_models:
print("\n** SKIPPING SUPPORT MODEL DOWNLOADS PER USER REQUEST **")
logger.info("SKIPPING SUPPORT MODEL DOWNLOADS PER USER REQUEST")
else:
print("\n** CHECKING/UPDATING SUPPORT MODELS **")
download_bert()
download_sd1_clip()
download_sd2_clip()
download_realesrgan()
download_gfpgan()
download_codeformer()
download_clipseg()
download_safety_checker()
download_vaes()
logger.info("CHECKING/UPDATING SUPPORT MODELS")
download_support_models()
if opt.skip_sd_weights:
print("\n** SKIPPING DIFFUSION WEIGHTS DOWNLOAD PER USER REQUEST **")
logger.info("\n** SKIPPING DIFFUSION WEIGHTS DOWNLOAD PER USER REQUEST **")
elif models_to_download:
print("\n** DOWNLOADING DIFFUSION WEIGHTS **")
logger.info("\n** DOWNLOADING DIFFUSION WEIGHTS **")
process_and_execute(opt, models_to_download)
postscript(errors=errors)

View File

@ -9,7 +9,7 @@ import warnings
from dataclasses import dataclass,field
from pathlib import Path
from tempfile import TemporaryFile
from typing import List, Dict, Callable
from typing import List, Dict, Set, Callable
import requests
from diffusers import AutoencoderKL
@ -20,8 +20,8 @@ from tqdm import tqdm
import invokeai.configs as configs
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType
from ..stable_diffusion import StableDiffusionGeneratorPipeline
from ..util.logging import InvokeAILogger
@ -62,7 +62,6 @@ class ModelInstallList:
class UserSelections():
install_models: List[str]= field(default_factory=list)
remove_models: List[str]=field(default_factory=list)
purge_deleted_models: bool=field(default_factory=list)
install_cn_models: List[str] = field(default_factory=list)
remove_cn_models: List[str] = field(default_factory=list)
install_lora_models: List[str] = field(default_factory=list)
@ -72,6 +71,64 @@ class UserSelections():
scan_directory: Path = None
autoscan_on_startup: bool=False
import_model_paths: str=None
@dataclass
class ModelLoadInfo():
    name: str
    model_type: ModelType
    base_type: BaseModelType
    path: Path = None
    repo_id: str = None
    description: str = ''
    installed: bool = False
    recommended: bool = False

class ModelInstall(object):
    def __init__(self, config: InvokeAIAppConfig):
        self.config = config
        self.mgr = ModelManager(config.model_conf_path)
        self.datasets = OmegaConf.load(Dataset_path)

    def all_models(self) -> Dict[str, ModelLoadInfo]:
        '''
        Return dict of model_key=>ModelStatus
        '''
        model_dict = dict()

        # first populate with the entries in INITIAL_MODELS.yaml
        for key, value in self.datasets.items():
            name, base, model_type = ModelManager.parse_key(key)
            value['name'] = name
            value['base_type'] = base
            value['model_type'] = model_type
            model_dict[key] = ModelLoadInfo(**value)

        # supplement with entries in models.yaml
        installed_models = self.mgr.list_models()
        for base in installed_models.keys():
            for model_type in installed_models[base].keys():
                for name, value in installed_models[base][model_type].items():
                    key = ModelManager.create_key(name, base, model_type)
                    if key in model_dict:
                        model_dict[key].installed = True
                    else:
                        model_dict[key] = ModelLoadInfo(
                            name=name,
                            base_type=base,
                            model_type=model_type,
                            description=value.get('description'),
                            path=value.get('path'),
                            installed=True,
                        )
        return {x: model_dict[x] for x in sorted(model_dict.keys(), key=lambda y: model_dict[y].name.lower())}

    def starter_models(self) -> Set[str]:
        models = set()
        for key, value in self.datasets.items():
            name, base, model_type = ModelManager.parse_key(key)
            if model_type == ModelType.Pipeline:
                models.add(key)
        return models
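A usage sketch for the new class. It assumes a configured InvokeAIAppConfig (get_config() is the assumed accessor) and model keys of the form base/type/name from ModelManager.create_key; neither is confirmed by this diff:

from invokeai.app.services.config import InvokeAIAppConfig

config = InvokeAIAppConfig.get_config()  # assumed accessor
installer = ModelInstall(config)

for key, info in installer.all_models().items():
    mark = "*" if info.installed else " "
    print(f"[{mark}] {key:60s} {info.description}")

print("starter pipelines:", sorted(installer.starter_models()))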
def default_config_file():
return config.model_conf_path
@ -85,6 +142,15 @@ def initial_models():
return Datasets
return (Datasets := OmegaConf.load(Dataset_path)['diffusers'])
def add_models(model_manager, config_file_path: Path, models: List[tuple[str,str,str]]):
    print(f'Installing {models}')

def del_models(model_manager, config_file_path: Path, models: List[tuple[str,str,str]]):
    for base, model_type, name in models:
        logger.info(f"Deleting {name}...")
        model_manager.del_model(name, base, model_type)
    model_manager.commit(config_file_path)
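Per the unpacking in del_models, each entry is a (base, model_type, name) triple. An illustrative call; the string values are assumptions, not taken from this diff:

from pathlib import Path

# model_manager as constructed in install_requested_models below
del_models(
    model_manager,
    Path("configs/models.yaml"),        # illustrative config path
    [("sd-1", "lora", "my-old-lora")],  # (base, model_type, name)
)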
def install_requested_models(
diffusers: ModelInstallList = None,
controlnet: ModelInstallList = None,
@ -95,9 +161,8 @@ def install_requested_models(
external_models: List[str] = None,
scan_at_startup: bool = False,
precision: str = "float16",
purge_deleted: bool = False,
config_file_path: Path = None,
model_config_file_callback: Callable[[Path],Path] = None
model_config_file_callback: Callable[[Path],Path] = None,
):
"""
Entry point for installing/deleting starter models, or installing external models.
@ -110,40 +175,27 @@ def install_requested_models(
# prevent circular import here
from ..model_management import ModelManager
model_manager = ModelManager(OmegaConf.load(config_file_path), precision=precision)
if controlnet:
model_manager.install_controlnet_models(controlnet.install_models, access_token=access_token)
model_manager.delete_controlnet_models(controlnet.remove_models)
if lora:
model_manager.install_lora_models(lora.install_models, access_token=access_token)
model_manager.delete_lora_models(lora.remove_models)
for x in [controlnet, lora, ti, diffusers]:
if x:
add_models(model_manager, config_file_path, x.install_models)
del_models(model_manager, config_file_path, x.remove_models)
# if diffusers:
if ti:
model_manager.install_ti_models(ti.install_models, access_token=access_token)
model_manager.delete_ti_models(ti.remove_models)
if diffusers:
# TODO: Replace next three paragraphs with calls into new model manager
if diffusers.remove_models and len(diffusers.remove_models) > 0:
logger.info("Processing requested deletions")
for model in diffusers.remove_models:
logger.info(f"{model}...")
model_manager.del_model(model, delete_files=purge_deleted)
model_manager.commit(config_file_path)
if diffusers.install_models and len(diffusers.install_models) > 0:
logger.info("Installing requested models")
downloaded_paths = download_weight_datasets(
models=diffusers.install_models,
access_token=None,
precision=precision,
)
successful = {x:v for x,v in downloaded_paths.items() if v is not None}
if len(successful) > 0:
update_config_file(successful, config_file_path)
if len(successful) < len(diffusers.install_models):
unsuccessful = [x for x in downloaded_paths if downloaded_paths[x] is None]
logger.warning(f"Some of the model downloads were not successful: {unsuccessful}")
# if diffusers.install_models and len(diffusers.install_models) > 0:
# logger.info("Installing requested models")
# downloaded_paths = download_weight_datasets(
# models=diffusers.install_models,
# access_token=None,
# precision=precision,
# )
# successful = {x:v for x,v in downloaded_paths.items() if v is not None}
# if len(successful) > 0:
# update_config_file(successful, config_file_path)
# if len(successful) < len(diffusers.install_models):
# unsuccessful = [x for x in downloaded_paths if downloaded_paths[x] is None]
# logger.warning(f"Some of the model downloads were not successful: {unsuccessful}")
# due to above, we have to reload the model manager because conf file
# was changed behind its back
@ -156,8 +208,8 @@ def install_requested_models(
if len(external_models) > 0:
logger.info("INSTALLING EXTERNAL MODELS")
for path_url_or_repo in external_models:
logger.debug(path_url_or_repo)
try:
logger.debug(f'In install_requested_models; callback = {model_config_file_callback}')
model_manager.heuristic_import(
path_url_or_repo,
commit_to_conf=config_file_path,
@ -280,21 +332,18 @@ def _download_ckpt_weights(mconfig: DictConfig, access_token: str) -> Path:
# ---------------------------------------------
def download_from_hf(
model_class: object, model_name: str, **kwargs
model_class: object, model_name: str, destination: Path, **kwargs
):
logger = InvokeAILogger.getLogger('InvokeAI')
logger.addFilter(lambda x: 'fp16 is not a valid' not in x.getMessage())
path = config.cache_dir
model = model_class.from_pretrained(
model_name,
cache_dir=path,
resume_download=True,
**kwargs,
)
model_name = "--".join(("models", *model_name.split("/")))
return path / model_name if model else None
model.save_pretrained(destination, safe_serialization=True)
return destination
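With the new signature, callers pass an explicit destination folder instead of relying on the HF cache layout, and get that folder back on success. A sketch with an illustrative destination:

from pathlib import Path
from transformers import CLIPTokenizer

dest = download_from_hf(
    CLIPTokenizer,
    "openai/clip-vit-large-patch14",
    Path("models/core/convert/clip-vit-large-patch14"),  # destination is now required
)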
def _download_diffusion_weights(
mconfig: DictConfig, access_token: str, precision: str = "float32"