fix broken !import_model downloads

1. Now works with sites that produce lots of redirects, such as CIVITAI.
2. Derive the name of the destination model file from the HTTP Content-Disposition
   header, if present.
3. Swap \\ for / in file paths provided by users, to hopefully fix issues on
   Windows (see the sketch below, which covers all three changes).
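
A rough sketch of how the three changes fit together (illustrative only: derive_destination() is a made-up name; the real helpers added by this commit are url_attachment_name() and download_with_resume() in ldm/util.py, and the regex mirrors the one used there):

import re
from pathlib import Path

import requests


def derive_destination(url_or_path: str, dest_dir: Path) -> Path:
    # Normalize Windows-style separators supplied by the user (change 3).
    url_or_path = url_or_path.replace("\\", "/")

    if url_or_path.startswith(("http:", "https:", "ftp:")):
        # requests follows redirects by default, so hosts such as CIVITAI that
        # bounce through several hops still resolve to the real file (change 1).
        resp = requests.get(url_or_path, stream=True)
        disposition = resp.headers.get("Content-Disposition", "")
        match = re.search(r'filename="(.+)"', disposition)
        # Prefer the server-supplied filename; fall back to the URL tail (change 2).
        filename = match.group(1) if match else Path(url_or_path).name
    else:
        filename = Path(url_or_path).name

    return dest_dir / filename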
Lincoln Stein 2023-02-13 22:14:24 -05:00
parent 15a9412255
commit d38e7170fe
3 changed files with 210 additions and 119 deletions

View File

@ -1,29 +1,31 @@
import click
import os
import re
import sys
import shlex
import sys
import traceback
from argparse import Namespace
from pathlib import Path
from typing import Optional, Union, List
from typing import List, Optional, Union
import click
if sys.platform == "darwin":
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
from ldm.invoke.globals import Globals
import pyparsing # type: ignore
import ldm.invoke
from ldm.generate import Generate
from ldm.invoke.prompt_parser import PromptParser
from ldm.invoke.readline import get_completer, Completer
from ldm.invoke.args import Args, metadata_dumps, metadata_from_png, dream_cmd_from_png
from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata
from ldm.invoke.args import (Args, dream_cmd_from_png, metadata_dumps,
metadata_from_png)
from ldm.invoke.globals import Globals
from ldm.invoke.image_util import make_grid
from ldm.invoke.log import write_log
from ldm.invoke.model_manager import ModelManager
import ldm.invoke
import pyparsing # type: ignore
from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata
from ldm.invoke.prompt_parser import PromptParser
from ldm.invoke.readline import Completer, get_completer
from ldm.util import url_attachment_name
# global used in multiple functions (fix)
infile = None
@ -66,11 +68,11 @@ def main():
print(f'>> InvokeAI runtime directory is "{Globals.root}"')
# loading here to avoid long delays on startup
from ldm.generate import Generate
# these two lines prevent a horrible warning message from appearing
# when the frozen CLIP tokenizer is imported
import transformers # type: ignore
from ldm.generate import Generate
transformers.logging.set_verbosity_error()
import diffusers
diffusers.logging.set_verbosity_error()
@ -574,10 +576,12 @@ def set_default_output_dir(opt:Args, completer:Completer):
def import_model(model_path: str, gen, opt, completer):
'''
model_path can be (1) a URL to a .ckpt file; (2) a local .ckpt file path; or
(3) a huggingface repository id
'''
"""
model_path can be (1) a URL to a .ckpt file; (2) a local .ckpt file path;
(3) a huggingface repository id; or (4) a local directory containing a
diffusers model.
"""
model_path = model_path.replace('\\','/') # windows
model_name = None
if model_path.startswith(('http:','https:','ftp:')):
@ -653,7 +657,10 @@ def import_checkpoint_list(models: List[Path], gen, opt, completer)->List[str]:
print()
return model_names
def import_diffuser_model(path_or_repo: Union[Path, str], gen, _, completer) -> Optional[str]:
def import_diffuser_model(
path_or_repo: Union[Path, str], gen, _, completer
) -> Optional[str]:
path_or_repo = path_or_repo.replace('\\','/') # windows
manager = gen.model_manager
default_name = Path(path_or_repo).stem
default_description = f'Imported model {default_name}'
@ -676,17 +683,16 @@ def import_diffuser_model(path_or_repo: Union[Path, str], gen, _, completer) ->
return None
return model_name
def import_ckpt_model(path_or_url: Union[Path, str], gen, opt, completer) -> Optional[str]:
def import_ckpt_model(
path_or_url: Union[Path, str], gen, opt, completer
) -> Optional[str]:
path_or_url = path_or_url.replace('\\','/')
manager = gen.model_manager
if not (config_file := _ask_for_config_file(path_or_url, completer)):
return
if config_file.stem == 'v2-inference-v':
print('** InvokeAI cannot run SD 2.X checkpoints directly. Model will be converted into diffusers format')
return optimize_model(path_or_url, gen, opt, completer, config_file)
is_a_url = str(path_or_url).startswith(('http:','https:'))
base_name = Path(url_attachment_name(path_or_url)).name if is_a_url else Path(path_or_url).name
default_name = Path(base_name).stem
default_description = f"Imported model {default_name}"
default_name = Path(path_or_url).stem
default_description = f'Imported model {default_name}'
model_name, model_description = _get_model_name_and_desc(
manager,
completer,
@ -776,7 +782,8 @@ def _ask_for_config_file(model_path: Union[str,Path], completer, plural: bool=Fa
return config_path
def optimize_model(model_name_or_path:str, gen, opt, completer, original_config_file: Path=None):
def optimize_model(model_name_or_path: Union[Path,str], gen, opt, completer):
model_name_or_path = model_name_or_path.replace('\\','/') # windows
manager = gen.model_manager
ckpt_path = None
@ -1067,6 +1074,7 @@ def get_next_command(infile=None, model_name='no model') -> str: # command stri
def invoke_ai_web_server_loop(gen: Generate, gfpgan, codeformer, esrgan):
print('\n* --web was specified, starting web server...')
from invokeai.backend import InvokeAIWebServer
# Change working directory to the stable-diffusion directory
os.chdir(
os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))

View File

@ -36,8 +36,8 @@ from ldm.invoke.generator.diffusers_pipeline import \
StableDiffusionGeneratorPipeline
from ldm.invoke.globals import (Globals, global_autoscan_dir, global_cache_dir,
global_models_dir)
from ldm.util import (ask_user, download_with_progress_bar,
instantiate_from_config)
from ldm.util import (ask_user, download_with_resume,
url_attachment_name, instantiate_from_config)
DEFAULT_MAX_MODELS = 2
VAE_TO_REPO_ID = { # hack, see note in convert_and_import()
@ -670,15 +670,18 @@ class ModelManager(object):
path to the configuration file, then the new entry will be committed to the
models.yaml file.
"""
if str(weights).startswith(("http:", "https:")):
model_name = model_name or url_attachment_name(weights)
weights_path = self._resolve_path(weights, "models/ldm/stable-diffusion-v1")
config_path = self._resolve_path(config, "configs/stable-diffusion")
config_path = self._resolve_path(config, "configs/stable-diffusion")
if weights_path is None or not weights_path.exists():
return False
if config_path is None or not config_path.exists():
return False
model_name = model_name or Path(weights).stem
model_name = model_name or Path(weights).stem # note this gives ugly pathnames if used on a URL without a Content-Disposition header
model_description = (
model_description or f"imported stable diffusion weights file {model_name}"
)
@ -748,7 +751,6 @@ class ModelManager(object):
into models.yaml.
"""
new_config = None
import transformers
from ldm.invoke.ckpt_to_diffuser import convert_ckpt_to_diffuser
@ -967,16 +969,15 @@ class ModelManager(object):
print("** Migration is done. Continuing...")
def _resolve_path(
self, source: Union[str, Path], dest_directory: str
self, source: Union[str, Path], dest_directory: str
) -> Optional[Path]:
resolved_path = None
if str(source).startswith(("http:", "https:", "ftp:")):
basename = os.path.basename(source)
if not os.path.isabs(dest_directory):
dest_directory = os.path.join(Globals.root, dest_directory)
dest = os.path.join(dest_directory, basename)
if download_with_progress_bar(str(source), Path(dest)):
resolved_path = Path(dest)
dest_directory = Path(dest_directory)
if not dest_directory.is_absolute():
dest_directory = Globals.root / dest_directory
dest_directory.mkdir(parents=True, exist_ok=True)
resolved_path = download_with_resume(str(source), dest_directory)
else:
if not os.path.isabs(source):
source = os.path.join(Globals.root, source)

View File

@ -1,20 +1,21 @@
import importlib
import math
import multiprocessing as mp
import os
import re
from collections import abc
from inspect import isfunction
from pathlib import Path
from queue import Queue
from threading import Thread
from urllib import request
from tqdm import tqdm
from pathlib import Path
from ldm.invoke.devices import torch_dtype
import numpy as np
import requests
import torch
import os
import traceback
from PIL import Image, ImageDraw, ImageFont
from tqdm import tqdm
from ldm.invoke.devices import torch_dtype
def log_txt_as_img(wh, xc, size=10):
@ -23,18 +24,18 @@ def log_txt_as_img(wh, xc, size=10):
b = len(xc)
txts = list()
for bi in range(b):
txt = Image.new('RGB', wh, color='white')
txt = Image.new("RGB", wh, color="white")
draw = ImageDraw.Draw(txt)
font = ImageFont.load_default()
nc = int(40 * (wh[0] / 256))
lines = '\n'.join(
lines = "\n".join(
xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc)
)
try:
draw.text((0, 0), lines, fill='black', font=font)
draw.text((0, 0), lines, fill="black", font=font)
except UnicodeEncodeError:
print('Cant encode string for logging. Skipping.')
print("Cant encode string for logging. Skipping.")
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
txts.append(txt)
@ -77,25 +78,23 @@ def count_params(model, verbose=False):
total_params = sum(p.numel() for p in model.parameters())
if verbose:
print(
f' | {model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.'
f" | {model.__class__.__name__} has {total_params * 1.e-6:.2f} M params."
)
return total_params
def instantiate_from_config(config, **kwargs):
if not 'target' in config:
if config == '__is_first_stage__':
if not "target" in config:
if config == "__is_first_stage__":
return None
elif config == '__is_unconditional__':
elif config == "__is_unconditional__":
return None
raise KeyError('Expected key `target` to instantiate.')
return get_obj_from_str(config['target'])(
**config.get('params', dict()), **kwargs
)
raise KeyError("Expected key `target` to instantiate.")
return get_obj_from_str(config["target"])(**config.get("params", dict()), **kwargs)
def get_obj_from_str(string, reload=False):
module, cls = string.rsplit('.', 1)
module, cls = string.rsplit(".", 1)
if reload:
module_imp = importlib.import_module(module)
importlib.reload(module_imp)
@ -111,14 +110,14 @@ def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
else:
res = func(data)
Q.put([idx, res])
Q.put('Done')
Q.put("Done")
def parallel_data_prefetch(
func: callable,
data,
n_proc,
target_data_type='ndarray',
target_data_type="ndarray",
cpu_intensive=True,
use_worker_id=False,
):
@ -126,21 +125,21 @@ def parallel_data_prefetch(
# raise ValueError(
# "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
# )
if isinstance(data, np.ndarray) and target_data_type == 'list':
raise ValueError('list expected but function got ndarray.')
if isinstance(data, np.ndarray) and target_data_type == "list":
raise ValueError("list expected but function got ndarray.")
elif isinstance(data, abc.Iterable):
if isinstance(data, dict):
print(
f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
)
data = list(data.values())
if target_data_type == 'ndarray':
if target_data_type == "ndarray":
data = np.asarray(data)
else:
data = list(data)
else:
raise TypeError(
f'The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}.'
f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
)
if cpu_intensive:
@ -150,7 +149,7 @@ def parallel_data_prefetch(
Q = Queue(1000)
proc = Thread
# spawn processes
if target_data_type == 'ndarray':
if target_data_type == "ndarray":
arguments = [
[func, Q, part, i, use_worker_id]
for i, part in enumerate(np.array_split(data, n_proc))
@ -173,7 +172,7 @@ def parallel_data_prefetch(
processes += [p]
# start processes
print(f'Start prefetching...')
print("Start prefetching...")
import time
start = time.time()
@ -186,13 +185,13 @@ def parallel_data_prefetch(
while k < n_proc:
# get result
res = Q.get()
if res == 'Done':
if res == "Done":
k += 1
else:
gather_res[res[0]] = res[1]
except Exception as e:
print('Exception: ', e)
print("Exception: ", e)
for p in processes:
p.terminate()
@ -200,15 +199,15 @@ def parallel_data_prefetch(
finally:
for p in processes:
p.join()
print(f'Prefetching complete. [{time.time() - start} sec.]')
print(f"Prefetching complete. [{time.time() - start} sec.]")
if target_data_type == 'ndarray':
if target_data_type == "ndarray":
if not isinstance(gather_res[0], np.ndarray):
return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
# order outputs
return np.concatenate(gather_res, axis=0)
elif target_data_type == 'list':
elif target_data_type == "list":
out = []
for r in gather_res:
out.extend(r)
@ -216,49 +215,79 @@ def parallel_data_prefetch(
else:
return gather_res
def rand_perlin_2d(shape, res, device, fade = lambda t: 6*t**5 - 15*t**4 + 10*t**3):
def rand_perlin_2d(
shape, res, device, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3
):
delta = (res[0] / shape[0], res[1] / shape[1])
d = (shape[0] // res[0], shape[1] // res[1])
grid = torch.stack(torch.meshgrid(torch.arange(0, res[0], delta[0]), torch.arange(0, res[1], delta[1]), indexing='ij'), dim = -1).to(device) % 1
grid = (
torch.stack(
torch.meshgrid(
torch.arange(0, res[0], delta[0]),
torch.arange(0, res[1], delta[1]),
indexing="ij",
),
dim=-1,
).to(device)
% 1
)
rand_val = torch.rand(res[0]+1, res[1]+1)
rand_val = torch.rand(res[0] + 1, res[1] + 1)
angles = 2*math.pi*rand_val
gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim = -1).to(device)
angles = 2 * math.pi * rand_val
gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1).to(device)
tile_grads = lambda slice1, slice2: gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]].repeat_interleave(d[0], 0).repeat_interleave(d[1], 1)
tile_grads = (
lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
.repeat_interleave(d[0], 0)
.repeat_interleave(d[1], 1)
)
dot = lambda grad, shift: (torch.stack((grid[:shape[0],:shape[1],0] + shift[0], grid[:shape[0],:shape[1], 1] + shift[1] ), dim = -1) * grad[:shape[0], :shape[1]]).sum(dim = -1)
dot = lambda grad, shift: (
torch.stack(
(
grid[: shape[0], : shape[1], 0] + shift[0],
grid[: shape[0], : shape[1], 1] + shift[1],
),
dim=-1,
)
* grad[: shape[0], : shape[1]]
).sum(dim=-1)
n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]).to(device)
n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]).to(device)
n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]).to(device)
n01 = dot(tile_grads([0, -1],[1, None]), [0, -1]).to(device)
n11 = dot(tile_grads([1, None], [1, None]), [-1,-1]).to(device)
t = fade(grid[:shape[0], :shape[1]])
noise = math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]).to(device)
n01 = dot(tile_grads([0, -1], [1, None]), [0, -1]).to(device)
n11 = dot(tile_grads([1, None], [1, None]), [-1, -1]).to(device)
t = fade(grid[: shape[0], : shape[1]])
noise = math.sqrt(2) * torch.lerp(
torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]
).to(device)
return noise.to(dtype=torch_dtype(device))
def ask_user(question: str, answers: list):
from itertools import chain, repeat
user_prompt = f'\n>> {question} {answers}: '
invalid_answer_msg = 'Invalid answer. Please try again.'
pose_question = chain([user_prompt], repeat('\n'.join([invalid_answer_msg, user_prompt])))
user_prompt = f"\n>> {question} {answers}: "
invalid_answer_msg = "Invalid answer. Please try again."
pose_question = chain(
[user_prompt], repeat("\n".join([invalid_answer_msg, user_prompt]))
)
user_answers = map(input, pose_question)
valid_response = next(filter(answers.__contains__, user_answers))
return valid_response
def debug_image(debug_image, debug_text, debug_show=True, debug_result=False, debug_status=False ):
def debug_image(
debug_image, debug_text, debug_show=True, debug_result=False, debug_status=False
):
if not debug_status:
return
image_copy = debug_image.copy().convert("RGBA")
ImageDraw.Draw(image_copy).text(
(5, 5),
debug_text,
(255, 0, 0)
)
ImageDraw.Draw(image_copy).text((5, 5), debug_text, (255, 0, 0))
if debug_show:
image_copy.show()
@ -266,31 +295,84 @@ def debug_image(debug_image, debug_text, debug_show=True, debug_result=False, de
if debug_result:
return image_copy
#-------------------------------------
class ProgressBar():
def __init__(self,model_name='file'):
self.pbar = None
self.name = model_name
def __call__(self, block_num, block_size, total_size):
if not self.pbar:
self.pbar=tqdm(desc=self.name,
initial=0,
unit='iB',
unit_scale=True,
unit_divisor=1000,
total=total_size)
self.pbar.update(block_size)
# -------------------------------------
def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path:
'''
Download a model file.
:param url: https, http or ftp URL
:param dest: A Path object. If path exists and is a directory, then we try to derive the filename
from the URL's Content-Disposition header and copy the URL contents into
dest/filename
:param access_token: Access token to access this resource
'''
resp = requests.get(url, stream=True)
total = int(resp.headers.get("content-length", 0))
if dest.is_dir():
try:
file_name = re.search('filename="(.+)"', resp.headers.get("Content-Disposition")).group(1)
except:
file_name = os.path.basename(url)
dest = dest / file_name
else:
dest.parent.mkdir(parents=True, exist_ok=True)
print(f'DEBUG: after many manipulations, dest={dest}')
header = {"Authorization": f"Bearer {access_token}"} if access_token else {}
open_mode = "wb"
exist_size = 0
if dest.exists():
exist_size = dest.stat().st_size
header["Range"] = f"bytes={exist_size}-"
open_mode = "ab"
if (
resp.status_code == 416
): # "range not satisfiable", which means nothing to return
print(f"* {dest}: complete file found. Skipping.")
return dest
elif resp.status_code != 200:
print(f"** An error occurred during downloading {dest}: {resp.reason}")
elif exist_size > 0:
print(f"* {dest}: partial file found. Resuming...")
else:
print(f"* {dest}: Downloading...")
def download_with_progress_bar(url:str, dest:Path)->bool:
try:
if not dest.exists():
dest.parent.mkdir(parents=True, exist_ok=True)
request.urlretrieve(url,dest,ProgressBar(dest.stem))
return True
else:
return True
except OSError:
print(traceback.format_exc())
return False
if total < 2000:
print(f"*** ERROR DOWNLOADING {url}: {resp.text}")
return None
with open(dest, open_mode) as file, tqdm(
desc=str(dest),
initial=exist_size,
total=total + exist_size,
unit="iB",
unit_scale=True,
unit_divisor=1000,
) as bar:
for data in resp.iter_content(chunk_size=1024):
size = file.write(data)
bar.update(size)
except Exception as e:
print(f"An error occurred while downloading {dest}: {str(e)}")
return None
return dest
def url_attachment_name(url: str) -> dict:
try:
resp = requests.get(url, stream=True)
match = re.search('filename="(.+)"', resp.headers.get("Content-Disposition"))
return match.group(1)
except:
return None
def download_with_progress_bar(url: str, dest: Path) -> bool:
result = download_with_resume(url, dest, access_token=None)
return result is not None
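
For reference, a minimal usage sketch of the new helpers (the URL below is a placeholder, not a value from this commit; the destination directory matches the default used by ModelManager):

from pathlib import Path

from ldm.util import download_with_resume, url_attachment_name

# Placeholder URL for illustration only.
url = "https://example.com/models/some-model.ckpt"

# Ask the server what it wants the file to be called (Content-Disposition header).
print(url_attachment_name(url))

# download_with_resume() expects an existing directory when it is to pick the
# filename itself; an interrupted download is resumed on the next call via an
# HTTP Range header.
dest_dir = Path("models/ldm/stable-diffusion-v1")
dest_dir.mkdir(parents=True, exist_ok=True)
dest = download_with_resume(url, dest_dir)
print(dest)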