Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
add controlnet model downloading
@@ -8,11 +8,11 @@ import sys
 import warnings
 from pathlib import Path
 from tempfile import TemporaryFile
-from typing import List
+from typing import List, Dict
 
 import requests
 from diffusers import AutoencoderKL
-from huggingface_hub import hf_hub_url
+from huggingface_hub import hf_hub_url, HfFolder
 from omegaconf import OmegaConf
 from omegaconf.dictconfig import DictConfig
 from tqdm import tqdm
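Note (editorial, not part of the diff): the new HfFolder import is what lets install_requested_models pick up a Hugging Face access token previously cached with huggingface-cli login. A minimal sketch of that lookup, assuming huggingface_hub is installed:

from huggingface_hub import HfFolder

# Returns the cached access token, or None if the user never logged in.
token = HfFolder.get_token()
print("token found" if token else "no cached token; downloads from gated repos may fail")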
@@ -49,7 +49,6 @@ Config_preamble = """
 
 
 def default_config_file():
-    print(config.root_dir)
     return config.model_conf_path
 
 def sd_configs():
@@ -62,23 +61,35 @@ def initial_models():
     return (Datasets := OmegaConf.load(Dataset_path)['diffusers'])
 
 def install_requested_models(
-    install_initial_models: List[str] = None,
-    remove_models: List[str] = None,
-    scan_directory: Path = None,
-    external_models: List[str] = None,
-    scan_at_startup: bool = False,
-    precision: str = "float16",
-    purge_deleted: bool = False,
-    config_file_path: Path = None,
+    install_initial_models: List[str] = None,
+    remove_models: List[str] = None,
+    install_cn_models: List[str] = None,
+    remove_cn_models: List[str] = None,
+    cn_model_map: Dict[str,str] = None,
+    scan_directory: Path = None,
+    external_models: List[str] = None,
+    scan_at_startup: bool = False,
+    precision: str = "float16",
+    purge_deleted: bool = False,
+    config_file_path: Path = None,
 ):
     """
     Entry point for installing/deleting starter models, or installing external models.
     """
+    access_token = HfFolder.get_token()
     config_file_path = config_file_path or default_config_file()
     if not config_file_path.exists():
         open(config_file_path, "w")
 
-    model_manager = ModelManager(OmegaConf.load(config_file_path)['diffusers'], precision=precision)
+    install_controlnet_models(
+        install_cn_models,
+        short_name_map = cn_model_map,
+        precision=precision,
+        access_token=access_token,
+    )
+    delete_controlnet_models(remove_cn_models)
+
+    model_manager = ModelManager(OmegaConf.load(config_file_path), precision=precision)
 
     if remove_models and len(remove_models) > 0:
         print("== DELETING UNCHECKED STARTER MODELS ==")
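Note (editorial, not part of the diff): a sketch of how the new ControlNet arguments might be passed by the installer front end. The short names and repo ids in cn_model_map below are illustrative assumptions, not values taken from this commit:

install_requested_models(
    install_cn_models=["canny", "depth"],    # short names selected in the UI (hypothetical)
    remove_cn_models=["openpose"],           # short names deselected in the UI (hypothetical)
    cn_model_map={                           # short name -> HuggingFace repo_id
        "canny": "lllyasviel/sd-controlnet-canny",
        "depth": "lllyasviel/sd-controlnet-depth",
        "openpose": "lllyasviel/sd-controlnet-openpose",
    },
    precision="float16",
)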
@@ -120,18 +131,20 @@ def install_requested_models(
             pass
 
     if scan_at_startup and scan_directory.is_dir():
-        argument = "--autoconvert"
-        print('** The global initfile is no longer supported; rewrite to support new yaml format **')
-        initfile = Path(config.root_dir, 'invokeai.init')
-        replacement = Path(config.root_dir, f"invokeai.init.new")
-        directory = str(scan_directory).replace("\\", "/")
-        with open(initfile, "r") as input:
-            with open(replacement, "w") as output:
-                while line := input.readline():
-                    if not line.startswith(argument):
-                        output.writelines([line])
-                output.writelines([f"{argument} {directory}"])
-        os.replace(replacement, initfile)
+        update_autoconvert_dir(scan_directory)
+
+def update_autoconvert_dir(autodir: Path):
+    '''
+    Update the "autoconvert_dir" option in invokeai.yaml
+    '''
+    invokeai_config_path = config.init_file_path
+    conf = OmegaConf.load(invokeai_config_path)
+    conf.InvokeAI.Paths.autoconvert_dir = str(autodir)
+    yaml = OmegaConf.to_yaml(conf)
+    tmpfile = invokeai_config_path.parent / "new_config.tmp"
+    with open(tmpfile, "w", encoding="utf-8") as outfile:
+        outfile.write(yaml)
+    tmpfile.replace(invokeai_config_path)
 
 
 # -------------------------------------
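Note (editorial, not part of the diff): update_autoconvert_dir replaces the old rewrite of invokeai.init with an OmegaConf round-trip on invokeai.yaml. A self-contained sketch of the same pattern, with a hypothetical config path:

from pathlib import Path
from omegaconf import OmegaConf

config_path = Path("invokeai.yaml")                          # hypothetical location, for illustration only
conf = OmegaConf.load(config_path)
conf.InvokeAI.Paths.autoconvert_dir = "/data/autoconvert"    # dotted access updates the nested key
tmp = config_path.with_suffix(".tmp")
tmp.write_text(OmegaConf.to_yaml(conf), encoding="utf-8")
tmp.replace(config_path)                                     # swap in the new file only after the write succeeds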
@@ -227,6 +240,68 @@ def _download_ckpt_weights(mconfig: DictConfig, access_token: str) -> Path:
     )
 
 
+# ---------------------------------------------
+def install_controlnet_models(
+    short_names: List[str],
+    short_name_map: Dict[str,str],
+    precision: str='float16',
+    access_token: str = None,
+):
+    '''
+    Download list of controlnet models, using their HuggingFace
+    repo_ids.
+    '''
+    dest_dir = config.controlnet_path
+    if not dest_dir.exists():
+        dest_dir.mkdir(parents=True,exist_ok=False)
+
+    # The model file may be fp32 or fp16, and may be either a
+    # .bin file or a .safetensors. We try each until we get one,
+    # preferring 'fp16' if using half precision, and preferring
+    # safetensors over bin.
+    precisions = ['.fp16',''] if precision=='float16' else ['']
+    formats = ['.safetensors','.bin']
+    possible_filenames = list()
+    for p in precisions:
+        for f in formats:
+            possible_filenames.append(Path(f'diffusion_pytorch_model{p}{f}'))
+
+    for directory_name in short_names:
+        repo_id = short_name_map[directory_name]
+        safe_name = directory_name.replace('/','--')
+        print(f'Downloading ControlNet model {directory_name} ({repo_id})')
+        hf_download_with_resume(
+            repo_id = repo_id,
+            model_dir = dest_dir / safe_name,
+            model_name = 'config.json',
+            access_token = access_token
+        )
+
+        path = None
+        for filename in possible_filenames:
+            suffix = filename.suffix
+            dest_filename = Path(f'diffusion_pytorch_model{suffix}')
+            print(f'Probing {directory_name}/{filename}...')
+            path = hf_download_with_resume(
+                repo_id = repo_id,
+                model_dir = dest_dir / safe_name,
+                model_name = str(filename),
+                access_token = access_token,
+                model_dest = Path(dest_dir, safe_name, dest_filename),
+            )
+            if path:
+                (path.parent / '.download_complete').touch()
+                break
+
+# ---------------------------------------------
+def delete_controlnet_models(short_names: List[str]):
+    for name in short_names:
+        safe_name = name.replace('/','--')
+        directory = config.controlnet_path / safe_name
+        if directory.exists():
+            print(f'Purging controlnet model {name}')
+            shutil.rmtree(str(directory))
+
+# ---------------------------------------------
 def download_from_hf(
     model_class: object, model_name: str, **kwargs
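Note (editorial, not part of the diff): because '/' in a short name is rewritten to '--', each ControlNet model lands in its own subdirectory of config.controlnet_path, and a .download_complete marker is written once a weights file arrives. A hypothetical helper showing how a caller could use that marker to skip models that are already installed:

from pathlib import Path

def is_controlnet_installed(controlnet_dir: Path, short_name: str) -> bool:
    # Mirrors the naming used in install_controlnet_models above.
    safe_name = short_name.replace("/", "--")
    return (controlnet_dir / safe_name / ".download_complete").exists()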
@@ -273,9 +348,13 @@ def _download_diffusion_weights(
 
 # ---------------------------------------------
 def hf_download_with_resume(
-    repo_id: str, model_dir: str, model_name: str, access_token: str = None
+    repo_id: str,
+    model_dir: str,
+    model_name: str,
+    model_dest: Path = None,
+    access_token: str = None,
 ) -> Path:
-    model_dest = Path(os.path.join(model_dir, model_name))
+    model_dest = model_dest or Path(os.path.join(model_dir, model_name))
     os.makedirs(model_dir, exist_ok=True)
 
     url = hf_hub_url(repo_id, model_name)
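Note (editorial, not part of the diff): the new model_dest parameter lets a caller fetch one remote filename but store it under another; install_controlnet_models uses it to save an fp16 or safetensors variant as a plain diffusion_pytorch_model.safetensors. A hypothetical call illustrating that, with an assumed repo id:

path = hf_download_with_resume(
    repo_id="lllyasviel/sd-controlnet-canny",                  # assumed repo id, for illustration
    model_dir="controlnet/canny",
    model_name="diffusion_pytorch_model.fp16.safetensors",     # filename probed on the Hub
    model_dest=Path("controlnet/canny/diffusion_pytorch_model.safetensors"),  # local filename
)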
@@ -297,18 +376,17 @@ def hf_download_with_resume(
     ): # "range not satisfiable", which means nothing to return
         print(f"* {model_name}: complete file found. Skipping.")
         return model_dest
+    elif resp.status_code == 404:
+        print("** File not found")
+        return None
     elif resp.status_code != 200:
-        print(f"** An error occurred during downloading {model_name}: {resp.reason}")
+        print(f"** Warning: {model_name}: {resp.reason}")
     elif exist_size > 0:
         print(f"* {model_name}: partial file found. Resuming...")
     else:
         print(f"* {model_name}: Downloading...")
 
     try:
-        if total < 2000:
-            print(f"*** ERROR DOWNLOADING {model_name}: {resp.text}")
-            return None
-
         with open(model_dest, open_mode) as file, tqdm(
             desc=model_name,
             initial=exist_size,
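Note (editorial, not part of the diff): the status-code handling above follows ordinary HTTP range semantics: 416 means the partial file on disk is already complete, 404 means the probed filename does not exist in the repo, and 200/206 stream the remaining bytes. A minimal, self-contained sketch of the same resume pattern using requests, independent of the InvokeAI helpers:

import os
import requests

def download_with_resume(url: str, dest: str) -> bool:
    exist_size = os.path.getsize(dest) if os.path.exists(dest) else 0
    headers = {"Range": f"bytes={exist_size}-"} if exist_size else {}
    resp = requests.get(url, headers=headers, stream=True)
    if resp.status_code == 416:          # "range not satisfiable": nothing left to fetch
        return True
    if resp.status_code == 404:          # file does not exist on the server
        return False
    if resp.status_code not in (200, 206):
        return False
    mode = "ab" if resp.status_code == 206 else "wb"   # a plain 200 means the server ignored the range
    with open(dest, mode) as f:
        for chunk in resp.iter_content(chunk_size=1 << 20):
            f.write(chunk)
    return True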