mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
rewrite of widget display - marshalling needs rewrite
This commit is contained in:
@ -16,6 +16,7 @@ import shutil
|
||||
import textwrap
|
||||
import traceback
|
||||
import warnings
|
||||
import yaml
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
from shutil import get_terminal_size
|
||||
@ -25,6 +26,7 @@ from urllib import request
|
||||
import npyscreen
|
||||
import transformers
|
||||
from diffusers import AutoencoderKL
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from huggingface_hub import HfFolder
|
||||
from huggingface_hub import login as hf_hub_login
|
||||
from omegaconf import OmegaConf
|
||||
@ -34,6 +36,8 @@ from transformers import (
|
||||
CLIPSegForImageSegmentation,
|
||||
CLIPTextModel,
|
||||
CLIPTokenizer,
|
||||
AutoFeatureExtractor,
|
||||
BertTokenizerFast,
|
||||
)
|
||||
import invokeai.configs as configs
|
||||
|
||||
@ -58,6 +62,9 @@ from invokeai.backend.install.model_install_backend import (
|
||||
recommended_datasets,
|
||||
UserSelections,
|
||||
)
|
||||
from invokeai.backend.model_management.model_probe import (
|
||||
ModelProbe, ModelType, BaseModelType, SchedulerPredictionType
|
||||
)
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
transformers.logging.set_verbosity_error()
|
||||
@ -81,7 +88,7 @@ INIT_FILE_PREAMBLE = """# InvokeAI initialization file
|
||||
# or renaming it and then running invokeai-configure again.
|
||||
"""
|
||||
|
||||
logger=None
|
||||
logger=InvokeAILogger.getLogger()
|
||||
|
||||
# --------------------------------------------
|
||||
def postscript(errors: None):
|
||||
@ -162,75 +169,91 @@ class ProgressBar:
|
||||
# ---------------------------------------------
|
||||
def download_with_progress_bar(model_url: str, model_dest: str, label: str = "the"):
|
||||
try:
|
||||
print(f"Installing {label} model file {model_url}...", end="", file=sys.stderr)
|
||||
logger.info(f"Installing {label} model file {model_url}...")
|
||||
if not os.path.exists(model_dest):
|
||||
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
|
||||
request.urlretrieve(
|
||||
model_url, model_dest, ProgressBar(os.path.basename(model_dest))
|
||||
)
|
||||
print("...downloaded successfully", file=sys.stderr)
|
||||
logger.info("...downloaded successfully")
|
||||
else:
|
||||
print("...exists", file=sys.stderr)
|
||||
logger.info("...exists")
|
||||
except Exception:
|
||||
print("...download failed", file=sys.stderr)
|
||||
print(f"Error downloading {label} model", file=sys.stderr)
|
||||
logger.info("...download failed")
|
||||
logger.info(f"Error downloading {label} model")
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
# this will preload the Bert tokenizer fles
|
||||
def download_bert():
|
||||
print("Installing bert tokenizer...", file=sys.stderr)
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||
from transformers import BertTokenizerFast
|
||||
def download_conversion_models():
|
||||
target_dir = config.root_path / 'models/core/convert'
|
||||
kwargs = dict() # for future use
|
||||
try:
|
||||
logger.info('Downloading core tokenizers and text encoders')
|
||||
|
||||
download_from_hf(BertTokenizerFast, "bert-base-uncased")
|
||||
# bert
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||
bert = BertTokenizerFast.from_pretrained("bert-base-uncased", **kwargs)
|
||||
bert.save_pretrained(target_dir / 'bert-base-uncased', safe_serialization=True)
|
||||
|
||||
# sd-1
|
||||
repo_id = 'openai/clip-vit-large-patch14'
|
||||
download_from_hf(CLIPTokenizer, repo_id, target_dir / 'clip-vit-large-patch14')
|
||||
download_from_hf(CLIPTextModel, repo_id, target_dir / 'clip-vit-large-patch14')
|
||||
|
||||
# sd-2
|
||||
repo_id = "stabilityai/stable-diffusion-2"
|
||||
pipeline = CLIPTokenizer.from_pretrained(repo_id, subfolder="tokenizer", **kwargs)
|
||||
pipeline.save_pretrained(target_dir / 'stable-diffusion-2-clip' / 'tokenizer', safe_serialization=True)
|
||||
|
||||
# ---------------------------------------------
|
||||
def download_sd1_clip():
|
||||
print("Installing SD1 clip model...", file=sys.stderr)
|
||||
version = "openai/clip-vit-large-patch14"
|
||||
download_from_hf(CLIPTokenizer, version)
|
||||
download_from_hf(CLIPTextModel, version)
|
||||
pipeline = CLIPTextModel.from_pretrained(repo_id, subfolder="text_encoder", **kwargs)
|
||||
pipeline.save_pretrained(target_dir / 'stable-diffusion-2-clip' / 'text_encoder', safe_serialization=True)
|
||||
|
||||
# VAE
|
||||
logger.info('Downloading stable diffusion VAE')
|
||||
vae = AutoencoderKL.from_pretrained('stabilityai/sd-vae-ft-mse', **kwargs)
|
||||
vae.save_pretrained(target_dir / 'sd-vae-ft-mse', safe_serialization=True)
|
||||
|
||||
# ---------------------------------------------
|
||||
def download_sd2_clip():
|
||||
version = "stabilityai/stable-diffusion-2"
|
||||
print("Installing SD2 clip model...", file=sys.stderr)
|
||||
download_from_hf(CLIPTokenizer, version, subfolder="tokenizer")
|
||||
download_from_hf(CLIPTextModel, version, subfolder="text_encoder")
|
||||
# safety checking
|
||||
logger.info('Downloading safety checker')
|
||||
repo_id = "CompVis/stable-diffusion-safety-checker"
|
||||
pipeline = AutoFeatureExtractor.from_pretrained(repo_id,**kwargs)
|
||||
pipeline.save_pretrained(target_dir / 'stable-diffusion-safety-checker', safe_serialization=True)
|
||||
|
||||
pipeline = StableDiffusionSafetyChecker.from_pretrained(repo_id,**kwargs)
|
||||
pipeline.save_pretrained(target_dir / 'stable-diffusion-safety-checker', safe_serialization=True)
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
|
||||
# ---------------------------------------------
|
||||
def download_realesrgan():
|
||||
print("Installing models from RealESRGAN...", file=sys.stderr)
|
||||
logger.info("Installing models from RealESRGAN...")
|
||||
model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"
|
||||
wdn_model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth"
|
||||
|
||||
model_dest = config.root_path / "models/realesrgan/realesr-general-x4v3.pth"
|
||||
wdn_model_dest = config.root_path / "models/realesrgan/realesr-general-wdn-x4v3.pth"
|
||||
model_dest = config.root_path / "models/core/upscaling/realesrgan/realesr-general-x4v3.pth"
|
||||
wdn_model_dest = config.root_path / "models/core/upscaling/realesrgan/realesr-general-wdn-x4v3.pth"
|
||||
|
||||
download_with_progress_bar(model_url, str(model_dest), "RealESRGAN")
|
||||
download_with_progress_bar(wdn_model_url, str(wdn_model_dest), "RealESRGANwdn")
|
||||
|
||||
|
||||
def download_gfpgan():
|
||||
print("Installing GFPGAN models...", file=sys.stderr)
|
||||
logger.info("Installing GFPGAN models...")
|
||||
for model in (
|
||||
[
|
||||
"https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth",
|
||||
"./models/gfpgan/GFPGANv1.4.pth",
|
||||
"./models/core/face_restoration/gfpgan/GFPGANv1.4.pth",
|
||||
],
|
||||
[
|
||||
"https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth",
|
||||
"./models/gfpgan/weights/detection_Resnet50_Final.pth",
|
||||
"./models/core/face_restoration/gfpgan/weights/detection_Resnet50_Final.pth",
|
||||
],
|
||||
[
|
||||
"https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth",
|
||||
"./models/gfpgan/weights/parsing_parsenet.pth",
|
||||
"./models/core/face_restoration/gfpgan/weights/parsing_parsenet.pth",
|
||||
],
|
||||
):
|
||||
model_url, model_dest = model[0], config.root_path / model[1]
|
||||
@ -239,70 +262,32 @@ def download_gfpgan():
|
||||
|
||||
# ---------------------------------------------
|
||||
def download_codeformer():
|
||||
print("Installing CodeFormer model file...", file=sys.stderr)
|
||||
logger.info("Installing CodeFormer model file...")
|
||||
model_url = (
|
||||
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
|
||||
)
|
||||
model_dest = config.root_path / "models/codeformer/codeformer.pth"
|
||||
model_dest = config.root_path / "models/core/face_restoration/codeformer/codeformer.pth"
|
||||
download_with_progress_bar(model_url, str(model_dest), "CodeFormer")
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def download_clipseg():
|
||||
print("Installing clipseg model for text-based masking...", file=sys.stderr)
|
||||
logger.info("Installing clipseg model for text-based masking...")
|
||||
CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
|
||||
try:
|
||||
download_from_hf(AutoProcessor, CLIPSEG_MODEL)
|
||||
download_from_hf(CLIPSegForImageSegmentation, CLIPSEG_MODEL)
|
||||
download_from_hf(AutoProcessor, CLIPSEG_MODEL, config.root_path / 'models/core/misc/clipseg')
|
||||
download_from_hf(CLIPSegForImageSegmentation, CLIPSEG_MODEL,'models/core/misc/clipseg')
|
||||
except Exception:
|
||||
print("Error installing clipseg model:")
|
||||
print(traceback.format_exc())
|
||||
logger.info("Error installing clipseg model:")
|
||||
logger.info(traceback.format_exc())
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def download_safety_checker():
|
||||
print("Installing model for NSFW content detection...", file=sys.stderr)
|
||||
try:
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import (
|
||||
StableDiffusionSafetyChecker,
|
||||
)
|
||||
from transformers import AutoFeatureExtractor
|
||||
except ModuleNotFoundError:
|
||||
print("Error installing NSFW checker model:")
|
||||
print(traceback.format_exc())
|
||||
return
|
||||
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
||||
print("AutoFeatureExtractor...", file=sys.stderr)
|
||||
download_from_hf(AutoFeatureExtractor, safety_model_id)
|
||||
print("StableDiffusionSafetyChecker...", file=sys.stderr)
|
||||
download_from_hf(StableDiffusionSafetyChecker, safety_model_id)
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def download_vaes():
|
||||
print("Installing stabilityai VAE...", file=sys.stderr)
|
||||
try:
|
||||
# first the diffusers version
|
||||
repo_id = "stabilityai/sd-vae-ft-mse"
|
||||
args = dict(
|
||||
cache_dir=config.cache_dir,
|
||||
)
|
||||
if not AutoencoderKL.from_pretrained(repo_id, **args):
|
||||
raise Exception(f"download of {repo_id} failed")
|
||||
|
||||
repo_id = "stabilityai/sd-vae-ft-mse-original"
|
||||
model_name = "vae-ft-mse-840000-ema-pruned.ckpt"
|
||||
# next the legacy checkpoint version
|
||||
if not hf_download_with_resume(
|
||||
repo_id=repo_id,
|
||||
model_name=model_name,
|
||||
model_dir=str(config.root_path / Model_dir / Weights_dir),
|
||||
):
|
||||
raise Exception(f"download of {model_name} failed")
|
||||
except Exception as e:
|
||||
print(f"Error downloading StabilityAI standard VAE: {str(e)}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
def download_support_models():
|
||||
download_realesrgan()
|
||||
download_gfpgan()
|
||||
download_codeformer()
|
||||
download_clipseg()
|
||||
download_conversion_models()
|
||||
|
||||
# -------------------------------------
|
||||
def get_root(root: str = None) -> str:
|
||||
@ -657,17 +642,13 @@ def default_user_selections(program_opts: Namespace) -> UserSelections:
|
||||
|
||||
# -------------------------------------
|
||||
def initialize_rootdir(root: Path, yes_to_all: bool = False):
|
||||
print("** INITIALIZING INVOKEAI RUNTIME DIRECTORY **")
|
||||
|
||||
logger.info("** INITIALIZING INVOKEAI RUNTIME DIRECTORY **")
|
||||
for name in (
|
||||
"models",
|
||||
"configs",
|
||||
"embeddings",
|
||||
"databases",
|
||||
"loras",
|
||||
"controlnets",
|
||||
"text-inversion-output",
|
||||
"text-inversion-training-data",
|
||||
"configs"
|
||||
):
|
||||
os.makedirs(os.path.join(root, name), exist_ok=True)
|
||||
|
||||
@ -676,6 +657,22 @@ def initialize_rootdir(root: Path, yes_to_all: bool = False):
|
||||
if not os.path.samefile(configs_src, configs_dest):
|
||||
shutil.copytree(configs_src, configs_dest, dirs_exist_ok=True)
|
||||
|
||||
dest = root / 'models'
|
||||
for model_base in [BaseModelType.StableDiffusion1,BaseModelType.StableDiffusion2]:
|
||||
for model_type in [ModelType.Pipeline, ModelType.Vae, ModelType.Lora,
|
||||
ModelType.ControlNet,ModelType.TextualInversion]:
|
||||
path = dest / model_base.value / model_type.value
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
path = dest / 'core'
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(root / 'configs' / 'models.yaml','w') as yaml_file:
|
||||
yaml_file.write(yaml.dump({'__metadata__':
|
||||
{'version':'3.0.0'}
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def run_console_ui(
|
||||
@ -837,7 +834,7 @@ def main():
|
||||
old_init_file = config.root_path / 'invokeai.init'
|
||||
new_init_file = config.root_path / 'invokeai.yaml'
|
||||
if old_init_file.exists() and not new_init_file.exists():
|
||||
print('** Migrating invokeai.init to invokeai.yaml')
|
||||
logger.info('** Migrating invokeai.init to invokeai.yaml')
|
||||
migrate_init_file(old_init_file)
|
||||
# Load new init file into config
|
||||
config.parse_args(argv=[],conf=OmegaConf.load(new_init_file))
|
||||
@ -855,29 +852,21 @@ def main():
|
||||
if init_options:
|
||||
write_opts(init_options, new_init_file)
|
||||
else:
|
||||
print(
|
||||
logger.info(
|
||||
'\n** CANCELLED AT USER\'S REQUEST. USE THE "invoke.sh" LAUNCHER TO RUN LATER **\n'
|
||||
)
|
||||
sys.exit(0)
|
||||
|
||||
if opt.skip_support_models:
|
||||
print("\n** SKIPPING SUPPORT MODEL DOWNLOADS PER USER REQUEST **")
|
||||
logger.info("SKIPPING SUPPORT MODEL DOWNLOADS PER USER REQUEST")
|
||||
else:
|
||||
print("\n** CHECKING/UPDATING SUPPORT MODELS **")
|
||||
download_bert()
|
||||
download_sd1_clip()
|
||||
download_sd2_clip()
|
||||
download_realesrgan()
|
||||
download_gfpgan()
|
||||
download_codeformer()
|
||||
download_clipseg()
|
||||
download_safety_checker()
|
||||
download_vaes()
|
||||
logger.info("CHECKING/UPDATING SUPPORT MODELS")
|
||||
download_support_models()
|
||||
|
||||
if opt.skip_sd_weights:
|
||||
print("\n** SKIPPING DIFFUSION WEIGHTS DOWNLOAD PER USER REQUEST **")
|
||||
logger.info("\n** SKIPPING DIFFUSION WEIGHTS DOWNLOAD PER USER REQUEST **")
|
||||
elif models_to_download:
|
||||
print("\n** DOWNLOADING DIFFUSION WEIGHTS **")
|
||||
logger.info("\n** DOWNLOADING DIFFUSION WEIGHTS **")
|
||||
process_and_execute(opt, models_to_download)
|
||||
|
||||
postscript(errors=errors)
|
||||
|
@ -9,7 +9,7 @@ import warnings
|
||||
from dataclasses import dataclass,field
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryFile
|
||||
from typing import List, Dict, Callable
|
||||
from typing import List, Dict, Set, Callable
|
||||
|
||||
import requests
|
||||
from diffusers import AutoencoderKL
|
||||
@ -20,8 +20,8 @@ from tqdm import tqdm
|
||||
|
||||
import invokeai.configs as configs
|
||||
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType
|
||||
from ..stable_diffusion import StableDiffusionGeneratorPipeline
|
||||
from ..util.logging import InvokeAILogger
|
||||
|
||||
@ -62,7 +62,6 @@ class ModelInstallList:
|
||||
class UserSelections():
|
||||
install_models: List[str]= field(default_factory=list)
|
||||
remove_models: List[str]=field(default_factory=list)
|
||||
purge_deleted_models: bool=field(default_factory=list)
|
||||
install_cn_models: List[str] = field(default_factory=list)
|
||||
remove_cn_models: List[str] = field(default_factory=list)
|
||||
install_lora_models: List[str] = field(default_factory=list)
|
||||
@ -72,6 +71,64 @@ class UserSelections():
|
||||
scan_directory: Path = None
|
||||
autoscan_on_startup: bool=False
|
||||
import_model_paths: str=None
|
||||
|
||||
@dataclass
|
||||
class ModelLoadInfo():
|
||||
name: str
|
||||
model_type: ModelType
|
||||
base_type: BaseModelType
|
||||
path: Path = None
|
||||
repo_id: str = None
|
||||
description: str = ''
|
||||
installed: bool = False
|
||||
recommended: bool = False
|
||||
|
||||
class ModelInstall(object):
|
||||
def __init__(self,config:InvokeAIAppConfig):
|
||||
self.config = config
|
||||
self.mgr = ModelManager(config.model_conf_path)
|
||||
self.datasets = OmegaConf.load(Dataset_path)
|
||||
|
||||
def all_models(self)->Dict[str,ModelLoadInfo]:
|
||||
'''
|
||||
Return dict of model_key=>ModelStatus
|
||||
'''
|
||||
model_dict = dict()
|
||||
# first populate with the entries in INITIAL_MODELS.yaml
|
||||
for key, value in self.datasets.items():
|
||||
name,base,model_type = ModelManager.parse_key(key)
|
||||
value['name'] = name
|
||||
value['base_type'] = base
|
||||
value['model_type'] = model_type
|
||||
model_dict[key] = ModelLoadInfo(**value)
|
||||
|
||||
# supplement with entries in models.yaml
|
||||
installed_models = self.mgr.list_models()
|
||||
for base in installed_models.keys():
|
||||
for model_type in installed_models[base].keys():
|
||||
for name, value in installed_models[base][model_type].items():
|
||||
key = ModelManager.create_key(name, base, model_type)
|
||||
if key in model_dict:
|
||||
model_dict[key].installed = True
|
||||
else:
|
||||
model_dict[key] = ModelLoadInfo(
|
||||
name = name,
|
||||
base_type = base,
|
||||
model_type = model_type,
|
||||
description = value.get('description'),
|
||||
path = value.get('path'),
|
||||
installed = True,
|
||||
)
|
||||
return {x : model_dict[x] for x in sorted(model_dict.keys(),key=lambda y: model_dict[y].name.lower())}
|
||||
|
||||
def starter_models(self)->Set[str]:
|
||||
models = set()
|
||||
for key, value in self.datasets.items():
|
||||
name,base,model_type = ModelManager.parse_key(key)
|
||||
if model_type==ModelType.Pipeline:
|
||||
models.add(key)
|
||||
return models
|
||||
|
||||
|
||||
def default_config_file():
|
||||
return config.model_conf_path
|
||||
@ -85,6 +142,15 @@ def initial_models():
|
||||
return Datasets
|
||||
return (Datasets := OmegaConf.load(Dataset_path)['diffusers'])
|
||||
|
||||
def add_models(model_manager, config_file_path: Path, models: List[tuple[str,str,str]]):
|
||||
print(f'Installing {models}')
|
||||
|
||||
def del_models(model_manager, config_file_path: Path, models: List[tuple[str,str,str]]):
|
||||
for base, model_type, name in models:
|
||||
logger.info(f"Deleting {name}...")
|
||||
model_manager.del_model(name, base, model_type)
|
||||
model_manager.commit(config_file_path)
|
||||
|
||||
def install_requested_models(
|
||||
diffusers: ModelInstallList = None,
|
||||
controlnet: ModelInstallList = None,
|
||||
@ -95,9 +161,8 @@ def install_requested_models(
|
||||
external_models: List[str] = None,
|
||||
scan_at_startup: bool = False,
|
||||
precision: str = "float16",
|
||||
purge_deleted: bool = False,
|
||||
config_file_path: Path = None,
|
||||
model_config_file_callback: Callable[[Path],Path] = None
|
||||
model_config_file_callback: Callable[[Path],Path] = None,
|
||||
):
|
||||
"""
|
||||
Entry point for installing/deleting starter models, or installing external models.
|
||||
@ -110,40 +175,27 @@ def install_requested_models(
|
||||
# prevent circular import here
|
||||
from ..model_management import ModelManager
|
||||
model_manager = ModelManager(OmegaConf.load(config_file_path), precision=precision)
|
||||
if controlnet:
|
||||
model_manager.install_controlnet_models(controlnet.install_models, access_token=access_token)
|
||||
model_manager.delete_controlnet_models(controlnet.remove_models)
|
||||
|
||||
if lora:
|
||||
model_manager.install_lora_models(lora.install_models, access_token=access_token)
|
||||
model_manager.delete_lora_models(lora.remove_models)
|
||||
for x in [controlnet, lora, ti, diffusers]:
|
||||
if x:
|
||||
add_models(model_manager, config_file_path, x.install_models)
|
||||
del_models(model_manager, config_file_path, x.remove_models)
|
||||
|
||||
# if diffusers:
|
||||
|
||||
if ti:
|
||||
model_manager.install_ti_models(ti.install_models, access_token=access_token)
|
||||
model_manager.delete_ti_models(ti.remove_models)
|
||||
|
||||
if diffusers:
|
||||
# TODO: Replace next three paragraphs with calls into new model manager
|
||||
if diffusers.remove_models and len(diffusers.remove_models) > 0:
|
||||
logger.info("Processing requested deletions")
|
||||
for model in diffusers.remove_models:
|
||||
logger.info(f"{model}...")
|
||||
model_manager.del_model(model, delete_files=purge_deleted)
|
||||
model_manager.commit(config_file_path)
|
||||
|
||||
if diffusers.install_models and len(diffusers.install_models) > 0:
|
||||
logger.info("Installing requested models")
|
||||
downloaded_paths = download_weight_datasets(
|
||||
models=diffusers.install_models,
|
||||
access_token=None,
|
||||
precision=precision,
|
||||
)
|
||||
successful = {x:v for x,v in downloaded_paths.items() if v is not None}
|
||||
if len(successful) > 0:
|
||||
update_config_file(successful, config_file_path)
|
||||
if len(successful) < len(diffusers.install_models):
|
||||
unsuccessful = [x for x in downloaded_paths if downloaded_paths[x] is None]
|
||||
logger.warning(f"Some of the model downloads were not successful: {unsuccessful}")
|
||||
# if diffusers.install_models and len(diffusers.install_models) > 0:
|
||||
# logger.info("Installing requested models")
|
||||
# downloaded_paths = download_weight_datasets(
|
||||
# models=diffusers.install_models,
|
||||
# access_token=None,
|
||||
# precision=precision,
|
||||
# )
|
||||
# successful = {x:v for x,v in downloaded_paths.items() if v is not None}
|
||||
# if len(successful) > 0:
|
||||
# update_config_file(successful, config_file_path)
|
||||
# if len(successful) < len(diffusers.install_models):
|
||||
# unsuccessful = [x for x in downloaded_paths if downloaded_paths[x] is None]
|
||||
# logger.warning(f"Some of the model downloads were not successful: {unsuccessful}")
|
||||
|
||||
# due to above, we have to reload the model manager because conf file
|
||||
# was changed behind its back
|
||||
@ -156,8 +208,8 @@ def install_requested_models(
|
||||
if len(external_models) > 0:
|
||||
logger.info("INSTALLING EXTERNAL MODELS")
|
||||
for path_url_or_repo in external_models:
|
||||
logger.debug(path_url_or_repo)
|
||||
try:
|
||||
logger.debug(f'In install_requested_models; callback = {model_config_file_callback}')
|
||||
model_manager.heuristic_import(
|
||||
path_url_or_repo,
|
||||
commit_to_conf=config_file_path,
|
||||
@ -280,21 +332,18 @@ def _download_ckpt_weights(mconfig: DictConfig, access_token: str) -> Path:
|
||||
|
||||
# ---------------------------------------------
|
||||
def download_from_hf(
|
||||
model_class: object, model_name: str, **kwargs
|
||||
model_class: object, model_name: str, destination: Path, **kwargs
|
||||
):
|
||||
logger = InvokeAILogger.getLogger('InvokeAI')
|
||||
logger.addFilter(lambda x: 'fp16 is not a valid' not in x.getMessage())
|
||||
|
||||
path = config.cache_dir
|
||||
model = model_class.from_pretrained(
|
||||
model_name,
|
||||
cache_dir=path,
|
||||
resume_download=True,
|
||||
**kwargs,
|
||||
)
|
||||
model_name = "--".join(("models", *model_name.split("/")))
|
||||
return path / model_name if model else None
|
||||
|
||||
model.save_pretrained(destination, safe_serialization=True)
|
||||
return destination
|
||||
|
||||
def _download_diffusion_weights(
|
||||
mconfig: DictConfig, access_token: str, precision: str = "float32"
|
||||
|
Reference in New Issue
Block a user