From d38e7170fee17f9362e92877407dfa2cf0108882 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 13 Feb 2023 22:14:24 -0500 Subject: [PATCH] fix broken !import_model downloads 1. Now works with sites that produce lots of redirects, such as CIVITAI 2. Derive name of destination model file from HTTP Content-Disposition header, if present. 3. Swap \\ for / in file paths provided by users, to hopefully fix issues with Windows. --- ldm/invoke/CLI.py | 66 +++++----- ldm/invoke/model_manager.py | 25 ++-- ldm/util.py | 238 ++++++++++++++++++++++++------------ 3 files changed, 210 insertions(+), 119 deletions(-) diff --git a/ldm/invoke/CLI.py b/ldm/invoke/CLI.py index 8c50cbd23e..f72f6058aa 100644 --- a/ldm/invoke/CLI.py +++ b/ldm/invoke/CLI.py @@ -1,29 +1,31 @@ -import click import os import re -import sys import shlex +import sys import traceback - from argparse import Namespace from pathlib import Path -from typing import Optional, Union, List +from typing import List, Optional, Union + +import click if sys.platform == "darwin": os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" -from ldm.invoke.globals import Globals +import pyparsing # type: ignore + +import ldm.invoke from ldm.generate import Generate -from ldm.invoke.prompt_parser import PromptParser -from ldm.invoke.readline import get_completer, Completer -from ldm.invoke.args import Args, metadata_dumps, metadata_from_png, dream_cmd_from_png -from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata +from ldm.invoke.args import (Args, dream_cmd_from_png, metadata_dumps, + metadata_from_png) +from ldm.invoke.globals import Globals from ldm.invoke.image_util import make_grid from ldm.invoke.log import write_log from ldm.invoke.model_manager import ModelManager - -import ldm.invoke -import pyparsing # type: ignore +from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata +from ldm.invoke.prompt_parser import PromptParser +from ldm.invoke.readline import Completer, get_completer +from ldm.util import url_attachment_name # global used in multiple functions (fix) infile = None @@ -66,11 +68,11 @@ def main(): print(f'>> InvokeAI runtime directory is "{Globals.root}"') # loading here to avoid long delays on startup - from ldm.generate import Generate - # these two lines prevent a horrible warning message from appearing # when the frozen CLIP tokenizer is imported import transformers # type: ignore + + from ldm.generate import Generate transformers.logging.set_verbosity_error() import diffusers diffusers.logging.set_verbosity_error() @@ -574,10 +576,12 @@ def set_default_output_dir(opt:Args, completer:Completer): def import_model(model_path: str, gen, opt, completer): - ''' - model_path can be (1) a URL to a .ckpt file; (2) a local .ckpt file path; or - (3) a huggingface repository id - ''' + """ + model_path can be (1) a URL to a .ckpt file; (2) a local .ckpt file path; + (3) a huggingface repository id; or (4) a local directory containing a + diffusers model. 
+ """ + model_path = model_path.replace('\\','/') # windows model_name = None if model_path.startswith(('http:','https:','ftp:')): @@ -653,7 +657,10 @@ def import_checkpoint_list(models: List[Path], gen, opt, completer)->List[str]: print() return model_names -def import_diffuser_model(path_or_repo: Union[Path, str], gen, _, completer) -> Optional[str]: +def import_diffuser_model( + path_or_repo: Union[Path, str], gen, _, completer +) -> Optional[str]: + path_or_repo = path_or_repo.replace('\\','/') # windows manager = gen.model_manager default_name = Path(path_or_repo).stem default_description = f'Imported model {default_name}' @@ -676,17 +683,16 @@ def import_diffuser_model(path_or_repo: Union[Path, str], gen, _, completer) -> return None return model_name -def import_ckpt_model(path_or_url: Union[Path, str], gen, opt, completer) -> Optional[str]: +def import_ckpt_model( + path_or_url: Union[Path, str], gen, opt, completer +) -> Optional[str]: + path_or_url = path_or_url.replace('\\','/') # windows manager = gen.model_manager - - if not (config_file := _ask_for_config_file(path_or_url, completer)): - return - if config_file.stem == 'v2-inference-v': - print('** InvokeAI cannot run SD 2.X checkpoints directly. Model will be converted into diffusers format') - return optimize_model(path_or_url, gen, opt, completer, config_file) + is_a_url = str(path_or_url).startswith(('http:','https:')) + base_name = Path(url_attachment_name(path_or_url)).name if is_a_url else Path(path_or_url).name + default_name = Path(base_name).stem + default_description = f"Imported model {default_name}" - default_name = Path(path_or_url).stem - default_description = f'Imported model {default_name}' model_name, model_description = _get_model_name_and_desc( manager, completer, @@ -776,7 +782,8 @@ def _ask_for_config_file(model_path: Union[str,Path], completer, plural: bool=Fa return config_path -def optimize_model(model_name_or_path:str, gen, opt, completer, original_config_file: Path=None): +def optimize_model(model_name_or_path: Union[Path,str], gen, opt, completer): + model_name_or_path = model_name_or_path.replace('\\','/') # windows manager = gen.model_manager ckpt_path = None @@ -1067,6 +1074,7 @@ def get_next_command(infile=None, model_name='no model') -> str: # command stri def invoke_ai_web_server_loop(gen: Generate, gfpgan, codeformer, esrgan): print('\n* --web was specified, starting web server...') from invokeai.backend import InvokeAIWebServer + # Change working directory to the stable-diffusion directory os.chdir( os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) diff --git a/ldm/invoke/model_manager.py b/ldm/invoke/model_manager.py index 3135931eea..76e6fdaa03 100644 --- a/ldm/invoke/model_manager.py +++ b/ldm/invoke/model_manager.py @@ -36,8 +36,8 @@ from ldm.invoke.generator.diffusers_pipeline import \ StableDiffusionGeneratorPipeline from ldm.invoke.globals import (Globals, global_autoscan_dir, global_cache_dir, global_models_dir) -from ldm.util import (ask_user, download_with_progress_bar, - instantiate_from_config) +from ldm.util import (ask_user, download_with_resume, + url_attachment_name, instantiate_from_config) DEFAULT_MAX_MODELS = 2 VAE_TO_REPO_ID = { # hack, see note in convert_and_import() @@ -670,15 +670,18 @@ class ModelManager(object): path to the configuration file, then the new entry will be committed to the models.yaml file.
""" + if str(weights).startswith(("http:", "https:")): + model_name = model_name or url_attachment_name(weights) + weights_path = self._resolve_path(weights, "models/ldm/stable-diffusion-v1") - config_path = self._resolve_path(config, "configs/stable-diffusion") + config_path = self._resolve_path(config, "configs/stable-diffusion") if weights_path is None or not weights_path.exists(): return False if config_path is None or not config_path.exists(): return False - model_name = model_name or Path(weights).stem + model_name = model_name or Path(weights).stem # note this gives ugly pathnames if used on a URL without a Content-Disposition header model_description = ( model_description or f"imported stable diffusion weights file {model_name}" ) @@ -748,7 +751,6 @@ class ModelManager(object): into models.yaml. """ new_config = None - import transformers from ldm.invoke.ckpt_to_diffuser import convert_ckpt_to_diffuser @@ -967,16 +969,15 @@ class ModelManager(object): print("** Migration is done. Continuing...") def _resolve_path( - self, source: Union[str, Path], dest_directory: str + self, source: Union[str, Path], dest_directory: str ) -> Optional[Path]: resolved_path = None if str(source).startswith(("http:", "https:", "ftp:")): - basename = os.path.basename(source) - if not os.path.isabs(dest_directory): - dest_directory = os.path.join(Globals.root, dest_directory) - dest = os.path.join(dest_directory, basename) - if download_with_progress_bar(str(source), Path(dest)): - resolved_path = Path(dest) + dest_directory = Path(dest_directory) + if not dest_directory.is_absolute(): + dest_directory = Globals.root / dest_directory + dest_directory.mkdir(parents=True, exist_ok=True) + resolved_path = download_with_resume(str(source), dest_directory) else: if not os.path.isabs(source): source = os.path.join(Globals.root, source) diff --git a/ldm/util.py b/ldm/util.py index 447875537f..d6d2c9e170 100644 --- a/ldm/util.py +++ b/ldm/util.py @@ -1,20 +1,21 @@ import importlib import math import multiprocessing as mp +import os +import re from collections import abc from inspect import isfunction +from pathlib import Path from queue import Queue from threading import Thread -from urllib import request -from tqdm import tqdm -from pathlib import Path -from ldm.invoke.devices import torch_dtype import numpy as np +import requests import torch -import os -import traceback from PIL import Image, ImageDraw, ImageFont +from tqdm import tqdm + +from ldm.invoke.devices import torch_dtype def log_txt_as_img(wh, xc, size=10): @@ -23,18 +24,18 @@ def log_txt_as_img(wh, xc, size=10): b = len(xc) txts = list() for bi in range(b): - txt = Image.new('RGB', wh, color='white') + txt = Image.new("RGB", wh, color="white") draw = ImageDraw.Draw(txt) font = ImageFont.load_default() nc = int(40 * (wh[0] / 256)) - lines = '\n'.join( + lines = "\n".join( xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc) ) try: - draw.text((0, 0), lines, fill='black', font=font) + draw.text((0, 0), lines, fill="black", font=font) except UnicodeEncodeError: - print('Cant encode string for logging. Skipping.') + print("Cant encode string for logging. Skipping.") txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 txts.append(txt) @@ -77,25 +78,23 @@ def count_params(model, verbose=False): total_params = sum(p.numel() for p in model.parameters()) if verbose: print( - f' | {model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.' + f" | {model.__class__.__name__} has {total_params * 1.e-6:.2f} M params." 
) return total_params def instantiate_from_config(config, **kwargs): - if not 'target' in config: - if config == '__is_first_stage__': + if not "target" in config: + if config == "__is_first_stage__": return None - elif config == '__is_unconditional__': + elif config == "__is_unconditional__": return None - raise KeyError('Expected key `target` to instantiate.') - return get_obj_from_str(config['target'])( - **config.get('params', dict()), **kwargs - ) + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict()), **kwargs) def get_obj_from_str(string, reload=False): - module, cls = string.rsplit('.', 1) + module, cls = string.rsplit(".", 1) if reload: module_imp = importlib.import_module(module) importlib.reload(module_imp) @@ -111,14 +110,14 @@ def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): else: res = func(data) Q.put([idx, res]) - Q.put('Done') + Q.put("Done") def parallel_data_prefetch( func: callable, data, n_proc, - target_data_type='ndarray', + target_data_type="ndarray", cpu_intensive=True, use_worker_id=False, ): @@ -126,21 +125,21 @@ def parallel_data_prefetch( # raise ValueError( # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." # ) - if isinstance(data, np.ndarray) and target_data_type == 'list': - raise ValueError('list expected but function got ndarray.') + if isinstance(data, np.ndarray) and target_data_type == "list": + raise ValueError("list expected but function got ndarray.") elif isinstance(data, abc.Iterable): if isinstance(data, dict): print( - f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' + 'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' ) data = list(data.values()) - if target_data_type == 'ndarray': + if target_data_type == "ndarray": data = np.asarray(data) else: data = list(data) else: raise TypeError( - f'The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}.' + f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." ) if cpu_intensive: @@ -150,7 +149,7 @@ def parallel_data_prefetch( Q = Queue(1000) proc = Thread # spawn processes - if target_data_type == 'ndarray': + if target_data_type == "ndarray": arguments = [ [func, Q, part, i, use_worker_id] for i, part in enumerate(np.array_split(data, n_proc)) @@ -173,7 +172,7 @@ def parallel_data_prefetch( processes += [p] # start processes - print(f'Start prefetching...') + print("Start prefetching...") import time start = time.time() @@ -186,13 +185,13 @@ def parallel_data_prefetch( while k < n_proc: # get result res = Q.get() - if res == 'Done': + if res == "Done": k += 1 else: gather_res[res[0]] = res[1] except Exception as e: - print('Exception: ', e) + print("Exception: ", e) for p in processes: p.terminate() @@ -200,15 +199,15 @@ def parallel_data_prefetch( finally: for p in processes: p.join() - print(f'Prefetching complete. [{time.time() - start} sec.]') + print(f"Prefetching complete. 
[{time.time() - start} sec.]") - if target_data_type == 'ndarray': + if target_data_type == "ndarray": if not isinstance(gather_res[0], np.ndarray): return np.concatenate([np.asarray(r) for r in gather_res], axis=0) # order outputs return np.concatenate(gather_res, axis=0) - elif target_data_type == 'list': + elif target_data_type == "list": out = [] for r in gather_res: out.extend(r) @@ -216,49 +215,79 @@ def parallel_data_prefetch( else: return gather_res -def rand_perlin_2d(shape, res, device, fade = lambda t: 6*t**5 - 15*t**4 + 10*t**3): + +def rand_perlin_2d( + shape, res, device, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3 +): delta = (res[0] / shape[0], res[1] / shape[1]) d = (shape[0] // res[0], shape[1] // res[1]) - grid = torch.stack(torch.meshgrid(torch.arange(0, res[0], delta[0]), torch.arange(0, res[1], delta[1]), indexing='ij'), dim = -1).to(device) % 1 + grid = ( + torch.stack( + torch.meshgrid( + torch.arange(0, res[0], delta[0]), + torch.arange(0, res[1], delta[1]), + indexing="ij", + ), + dim=-1, + ).to(device) + % 1 + ) - rand_val = torch.rand(res[0]+1, res[1]+1) + rand_val = torch.rand(res[0] + 1, res[1] + 1) - angles = 2*math.pi*rand_val - gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim = -1).to(device) + angles = 2 * math.pi * rand_val + gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1).to(device) - tile_grads = lambda slice1, slice2: gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]].repeat_interleave(d[0], 0).repeat_interleave(d[1], 1) + tile_grads = ( + lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]] + .repeat_interleave(d[0], 0) + .repeat_interleave(d[1], 1) + ) - dot = lambda grad, shift: (torch.stack((grid[:shape[0],:shape[1],0] + shift[0], grid[:shape[0],:shape[1], 1] + shift[1] ), dim = -1) * grad[:shape[0], :shape[1]]).sum(dim = -1) + dot = lambda grad, shift: ( + torch.stack( + ( + grid[: shape[0], : shape[1], 0] + shift[0], + grid[: shape[0], : shape[1], 1] + shift[1], + ), + dim=-1, + ) + * grad[: shape[0], : shape[1]] + ).sum(dim=-1) - n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]).to(device) + n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]).to(device) n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]).to(device) - n01 = dot(tile_grads([0, -1],[1, None]), [0, -1]).to(device) - n11 = dot(tile_grads([1, None], [1, None]), [-1,-1]).to(device) - t = fade(grid[:shape[0], :shape[1]]) - noise = math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]).to(device) + n01 = dot(tile_grads([0, -1], [1, None]), [0, -1]).to(device) + n11 = dot(tile_grads([1, None], [1, None]), [-1, -1]).to(device) + t = fade(grid[: shape[0], : shape[1]]) + noise = math.sqrt(2) * torch.lerp( + torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1] + ).to(device) return noise.to(dtype=torch_dtype(device)) + def ask_user(question: str, answers: list): from itertools import chain, repeat - user_prompt = f'\n>> {question} {answers}: ' - invalid_answer_msg = 'Invalid answer. Please try again.' - pose_question = chain([user_prompt], repeat('\n'.join([invalid_answer_msg, user_prompt]))) + + user_prompt = f"\n>> {question} {answers}: " + invalid_answer_msg = "Invalid answer. Please try again." 
+ pose_question = chain( + [user_prompt], repeat("\n".join([invalid_answer_msg, user_prompt])) + ) user_answers = map(input, pose_question) valid_response = next(filter(answers.__contains__, user_answers)) return valid_response -def debug_image(debug_image, debug_text, debug_show=True, debug_result=False, debug_status=False ): +def debug_image( + debug_image, debug_text, debug_show=True, debug_result=False, debug_status=False +): if not debug_status: return image_copy = debug_image.copy().convert("RGBA") - ImageDraw.Draw(image_copy).text( - (5, 5), - debug_text, - (255, 0, 0) - ) + ImageDraw.Draw(image_copy).text((5, 5), debug_text, (255, 0, 0)) if debug_show: image_copy.show() @@ -266,31 +295,84 @@ def debug_image(debug_image, debug_text, debug_show=True, debug_result=False, de if debug_result: return image_copy -#------------------------------------- -class ProgressBar(): - def __init__(self,model_name='file'): - self.pbar = None - self.name = model_name - def __call__(self, block_num, block_size, total_size): - if not self.pbar: - self.pbar=tqdm(desc=self.name, - initial=0, - unit='iB', - unit_scale=True, - unit_divisor=1000, - total=total_size) - self.pbar.update(block_size) +# ------------------------------------- +def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path: + ''' + Download a model file. + :param url: https, http or ftp URL + :param dest: A Path object. If path exists and is a directory, then we try to derive the filename + from the URL's Content-Disposition header and copy the URL contents into + dest/filename + :param access_token: Access token to access this resource + ''' + resp = requests.get(url, stream=True) + total = int(resp.headers.get("content-length", 0)) + + if dest.is_dir(): + try: + file_name = re.search('filename="(.+)"', resp.headers.get("Content-Disposition")).group(1) + except: + file_name = os.path.basename(url) + dest = dest / file_name + else: + dest.parent.mkdir(parents=True, exist_ok=True) + + print(f'DEBUG: after many manipulations, dest={dest}') + + header = {"Authorization": f"Bearer {access_token}"} if access_token else {} + open_mode = "wb" + exist_size = 0 + + if dest.exists(): + exist_size = dest.stat().st_size + header["Range"] = f"bytes={exist_size}-" + open_mode = "ab" + + if ( + resp.status_code == 416 + ): # "range not satisfiable", which means nothing to return + print(f"* {dest}: complete file found. Skipping.") + return dest + elif resp.status_code != 200: + print(f"** An error occurred during downloading {dest}: {resp.reason}") + elif exist_size > 0: + print(f"* {dest}: partial file found. 
Resuming...") + else: + print(f"* {dest}: Downloading...") -def download_with_progress_bar(url:str, dest:Path)->bool: try: - if not dest.exists(): - dest.parent.mkdir(parents=True, exist_ok=True) - request.urlretrieve(url,dest,ProgressBar(dest.stem)) - return True - else: - return True - except OSError: - print(traceback.format_exc()) - return False + if total < 2000: + print(f"*** ERROR DOWNLOADING {url}: {resp.text}") + return None + with open(dest, open_mode) as file, tqdm( + desc=str(dest), + initial=exist_size, + total=total + exist_size, + unit="iB", + unit_scale=True, + unit_divisor=1000, + ) as bar: + for data in resp.iter_content(chunk_size=1024): + size = file.write(data) + bar.update(size) + except Exception as e: + print(f"An error occurred while downloading {dest}: {str(e)}") + return None + + return dest + + +def url_attachment_name(url: str) -> str: + try: + resp = requests.get(url, stream=True) + match = re.search('filename="(.+)"', resp.headers.get("Content-Disposition")) + return match.group(1) + except: + return None + + +def download_with_progress_bar(url: str, dest: Path) -> bool: + result = download_with_resume(url, dest, access_token=None) + return result is not None
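
A minimal usage sketch of the two helpers this patch adds to ldm/util.py. The URL and the destination directory below are placeholders for illustration only and are not part of the patch; the printed filename is just an example.

from pathlib import Path

from ldm.util import download_with_resume, url_attachment_name

# Placeholder URL for illustration; any host that answers with redirects
# and a Content-Disposition header (e.g. a CIVITAI download link) will do.
url = "https://example.com/api/download/models/1234"

# Ask the server what the attachment should be called. Returns None when
# no Content-Disposition header is present.
print(url_attachment_name(url))  # e.g. 'someModel_v10.ckpt'

# Download into a directory: the filename is derived from the same header
# (falling back to the URL basename), and an existing partial file at that
# path is reopened in append mode rather than restarted from scratch.
dest = download_with_resume(url, Path("models/ldm/stable-diffusion-v1"))
print(f"saved to {dest}" if dest else "download failed")

Because requests.get() follows HTTP redirects by default, the same call path also covers the redirect-heavy hosts mentioned in point 1 of the commit message.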