diff --git a/docs/installation/050_INSTALLING_MODELS.md b/docs/installation/050_INSTALLING_MODELS.md index a610491798..5621075506 100644 --- a/docs/installation/050_INSTALLING_MODELS.md +++ b/docs/installation/050_INSTALLING_MODELS.md @@ -80,6 +80,13 @@ only `.safetensors` and `.ckpt` models, but they can be easily loaded into InvokeAI and/or converted into optimized `diffusers` models. Be aware that CIVITAI hosts many models that generate NSFW content. +!!! note + + InvokeAI 2.3.x does not support directly importing and + running Stable Diffusion version 2 checkpoint models. You may instead + convert them into `diffusers` models using the conversion methods + described below. + ## Installation There are multiple ways to install and manage models: @@ -90,7 +97,7 @@ There are multiple ways to install and manage models: models files. 3. The web interface (WebUI) has a GUI for importing and managing -models. + models. ### Installation via `invokeai-configure` @@ -106,7 +113,7 @@ confirm that the files are complete. You can install a new model, including any of the community-supported ones, via the command-line client's `!import_model` command. -#### Installing `.ckpt` and `.safetensors` models +#### Installing individual `.ckpt` and `.safetensors` models If the model is already downloaded to your local disk, use `!import_model /path/to/file.ckpt` to load it. For example: @@ -131,15 +138,40 @@ invoke> !import_model https://example.org/sd_models/martians.safetensors For this to work, the URL must not be password-protected. Otherwise you will receive a 404 error. -When you import a legacy model, the CLI will ask you a few questions -about the model, including what size image it was trained on (usually -512x512), what name and description you wish to use for it, what -configuration file to use for it (usually the default -`v1-inference.yaml`), whether you'd like to make this model the -default at startup time, and whether you would like to install a -custom VAE (variable autoencoder) file for the model. For recent -models, the answer to the VAE question is usually "no," but it won't -hurt to answer "yes". +When you import a legacy model, the CLI will first ask you what type +of model this is. You can indicate whether it is a model based on +Stable Diffusion 1.x (1.4 or 1.5), one based on Stable Diffusion 2.x, +or a 1.x inpainting model. Be careful to indicate the correct model +type, or it will not load correctly. You can correct the model type +after the fact using the `!edit_model` command. + +The system will then ask you a few other questions about the model, +including what size image it was trained on (usually 512x512), what +name and description you wish to use for it, and whether you would +like to install a custom VAE (variable autoencoder) file for the +model. For recent models, the answer to the VAE question is usually +"no," but it won't hurt to answer "yes". + +After importing, the model will load. If this is successful, you will +be asked if you want to keep the model loaded in memory to start +generating immediately. You'll also be asked if you wish to make this +the default model on startup. You can change this later using +`!edit_model`. + +#### Importing a batch of `.ckpt` and `.safetensors` models from a directory + +You may also point `!import_model` to a directory containing a set of +`.ckpt` or `.safetensors` files. They will be imported _en masse_. + +!!! example + + ```console + invoke> !import_model C:/Users/fred/Downloads/civitai_models/ + ``` + +You will be given the option to import all models found in the +directory, or select which ones to import. If there are subfolders +within the directory, they will be searched for models to import. #### Installing `diffusers` models @@ -284,14 +316,18 @@ up a dialogue that lists the models you have already installed, and allows you to load, delete or edit them:
+ ![model-manager](../assets/installing-models/webui-models-1.png) +
To add a new model, click on **+ Add New** and select to either a checkpoint/safetensors model, or a diffusers model:
+ ![model-manager-add-new](../assets/installing-models/webui-models-2.png) +
In this example, we chose **Add Diffusers**. As shown in the figure @@ -302,7 +338,9 @@ choose to enter a path to disk, the system will autocomplete for you as you type:
+ ![model-manager-add-diffusers](../assets/installing-models/webui-models-3.png) +
Press **Add Model** at the bottom of the dialogue (scrolled out of @@ -317,7 +355,9 @@ directory and press the "Search" icon. This will display the subfolders, and allow you to choose which ones to import:
+ ![model-manager-add-checkpoint](../assets/installing-models/webui-models-4.png) +
## Model Management Startup Options @@ -342,9 +382,8 @@ invoke.sh --autoconvert /home/fred/stable-diffusion-checkpoints And here is what the same argument looks like in `invokeai.init`: -``` +```bash --outdir="/home/fred/invokeai/outputs --no-nsfw_checker --autoconvert /home/fred/stable-diffusion-checkpoints ``` - diff --git a/ldm/invoke/CLI.py b/ldm/invoke/CLI.py index 32c6d816be..f72f6058aa 100644 --- a/ldm/invoke/CLI.py +++ b/ldm/invoke/CLI.py @@ -1,29 +1,31 @@ import os import re -import sys import shlex +import sys import traceback - from argparse import Namespace from pathlib import Path -from typing import Optional, Union +from typing import List, Optional, Union + +import click if sys.platform == "darwin": os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" -from ldm.invoke.globals import Globals +import pyparsing # type: ignore + +import ldm.invoke from ldm.generate import Generate -from ldm.invoke.prompt_parser import PromptParser -from ldm.invoke.readline import get_completer, Completer -from ldm.invoke.args import Args, metadata_dumps, metadata_from_png, dream_cmd_from_png -from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata +from ldm.invoke.args import (Args, dream_cmd_from_png, metadata_dumps, + metadata_from_png) +from ldm.invoke.globals import Globals from ldm.invoke.image_util import make_grid from ldm.invoke.log import write_log from ldm.invoke.model_manager import ModelManager - -import click # type: ignore -import ldm.invoke -import pyparsing # type: ignore +from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata +from ldm.invoke.prompt_parser import PromptParser +from ldm.invoke.readline import Completer, get_completer +from ldm.util import url_attachment_name # global used in multiple functions (fix) infile = None @@ -66,11 +68,11 @@ def main(): print(f'>> InvokeAI runtime directory is "{Globals.root}"') # loading here to avoid long delays on startup - from ldm.generate import Generate - # these two lines prevent a horrible warning message from appearing # when the frozen CLIP tokenizer is imported import transformers # type: ignore + + from ldm.generate import Generate transformers.logging.set_verbosity_error() import diffusers diffusers.logging.set_verbosity_error() @@ -574,10 +576,12 @@ def set_default_output_dir(opt:Args, completer:Completer): def import_model(model_path: str, gen, opt, completer): - ''' - model_path can be (1) a URL to a .ckpt file; (2) a local .ckpt file path; or - (3) a huggingface repository id - ''' + """ + model_path can be (1) a URL to a .ckpt file; (2) a local .ckpt file path; + (3) a huggingface repository id; or (4) a local directory containing a + diffusers model. + """ + model.path = model_path.replace('\\','/') # windows model_name = None if model_path.startswith(('http:','https:','ftp:')): @@ -592,12 +596,8 @@ def import_model(model_path: str, gen, opt, completer): models = list(Path(model_path).rglob('*.ckpt')) + list(Path(model_path).rglob('*.safetensors')) if models: - # Only the last model name will be used below. - for model in sorted(models): - - if click.confirm(f'Import {model.stem} ?', default=True): - model_name = import_ckpt_model(model, gen, opt, completer) - print() + models = import_checkpoint_list(models, gen, opt, completer) + model_name = models[0] if len(models) == 1 else None else: model_name = import_diffuser_model(Path(model_path), gen, opt, completer) @@ -614,14 +614,53 @@ def import_model(model_path: str, gen, opt, completer): print('** model failed to load. Discarding configuration entry') gen.model_manager.del_model(model_name) return - if input('Make this the default model? [n] ').strip() in ('y','Y'): + if click.confirm('Make this the default model?', default=False): gen.model_manager.set_default_model(model_name) gen.model_manager.commit(opt.conf) completer.update_models(gen.model_manager.list_models()) print(f'>> {model_name} successfully installed') -def import_diffuser_model(path_or_repo: Union[Path, str], gen, _, completer) -> Optional[str]: +def import_checkpoint_list(models: List[Path], gen, opt, completer)->List[str]: + ''' + Does a mass import of all the checkpoint/safetensors on a path list + ''' + model_names = list() + choice = input('** Directory of checkpoint/safetensors models detected. Install ll or elected models? [a] ') or 'a' + do_all = choice.startswith('a') + if do_all: + config_file = _ask_for_config_file(models[0], completer, plural=True) + manager = gen.model_manager + for model in sorted(models): + model_name = f'{model.stem}' + model_description = f'Imported model {model_name}' + if model_name in manager.model_names(): + print(f'** {model_name} is already imported. Skipping.') + elif manager.import_ckpt_model( + model, + config = config_file, + model_name = model_name, + model_description = model_description, + commit_to_conf = opt.conf): + model_names.append(model_name) + print(f'>> Model {model_name} imported successfully') + else: + print(f'** Model {model} failed to import') + else: + for model in sorted(models): + if click.confirm(f'Import {model.stem} ?', default=True): + if model_name := import_ckpt_model(model, gen, opt, completer): + print(f'>> Model {model.stem} imported successfully') + model_names.append(model_name) + else: + printf('** Model {model} failed to import') + print() + return model_names + +def import_diffuser_model( + path_or_repo: Union[Path, str], gen, _, completer +) -> Optional[str]: + path_or_repo = path_or_repo.replace('\\','/') # windows manager = gen.model_manager default_name = Path(path_or_repo).stem default_description = f'Imported model {default_name}' @@ -632,7 +671,7 @@ def import_diffuser_model(path_or_repo: Union[Path, str], gen, _, completer) -> model_description=default_description ) vae = None - if input('Replace this model\'s VAE with "stabilityai/sd-vae-ft-mse"? [n] ').strip() in ('y','Y'): + if click.confirm('Replace this model\'s VAE with "stabilityai/sd-vae-ft-mse"?', default=False): vae = dict(repo_id='stabilityai/sd-vae-ft-mse') if not manager.import_diffuser_model( @@ -644,27 +683,22 @@ def import_diffuser_model(path_or_repo: Union[Path, str], gen, _, completer) -> return None return model_name -def import_ckpt_model(path_or_url: Union[Path, str], gen, opt, completer) -> Optional[str]: +def import_ckpt_model( + path_or_url: Union[Path, str], gen, opt, completer +) -> Optional[str]: + path_or_url = path_or_url.replace('\\','/') manager = gen.model_manager - default_name = Path(path_or_url).stem - default_description = f'Imported model {default_name}' + is_a_url = str(path_or_url).startswith(('http:','https:')) + base_name = Path(url_attachment_name(path_or_url)).name if is_a_url else Path(path_or_url).name + default_name = Path(base_name).stem + default_description = f"Imported model {default_name}" + model_name, model_description = _get_model_name_and_desc( manager, completer, model_name=default_name, model_description=default_description ) - config_file = None - default = Path(Globals.root,'configs/stable-diffusion/v1-inpainting-inference.yaml') \ - if re.search('inpaint',default_name, flags=re.IGNORECASE) \ - else Path(Globals.root,'configs/stable-diffusion/v1-inference.yaml') - - completer.complete_extensions(('.yaml','.yml')) - completer.set_line(str(default)) - done = False - while not done: - config_file = input('Configuration file for this model: ').strip() - done = os.path.exists(config_file) completer.complete_extensions(('.ckpt','.safetensors')) vae = None @@ -692,10 +726,15 @@ def import_ckpt_model(path_or_url: Union[Path, str], gen, opt, completer) -> Opt def _verify_load(model_name:str, gen)->bool: print('>> Verifying that new model loads...') current_model = gen.model_name - if not gen.model_manager.get_model(model_name): + try: + if not gen.model_manager.get_model(model_name): + return False + except Exception as e: + print(f'** model failed to load: {str(e)}') + print('** note that importing 2.X checkpoints is not supported. Please use !convert_model instead.') return False - do_switch = input('Keep model loaded? [y] ') - if len(do_switch)==0 or do_switch[0] in ('y','Y'): + + if click.confirm('Keep model loaded?', default=True): gen.set_model(model_name) else: print('>> Restoring previous model') @@ -708,16 +747,45 @@ def _get_model_name_and_desc(model_manager,completer,model_name:str='',model_des model_description = input(f'Description for this model [{model_description}]: ').strip() or model_description return model_name, model_description -def _is_inpainting(model_name_or_path: str)->bool: - if re.search('inpaint',model_name_or_path, flags=re.IGNORECASE): - return not input('Is this an inpainting model? [y] ').startswith(('n','N')) - else: - return not input('Is this an inpainting model? [n] ').startswith(('y','Y')) +def _ask_for_config_file(model_path: Union[str,Path], completer, plural: bool=False)->Path: + default = '1' + if re.search('inpaint',str(model_path),flags=re.IGNORECASE): + default = '3' + choices={ + '1': 'v1-inference.yaml', + '2': 'v2-inference-v.yaml', + '3': 'v1-inpainting-inference.yaml', + } + + prompt = '''What type of models are these?: +[1] Models based on Stable Diffusion 1.X +[2] Models based on Stable Diffusion 2.X +[3] Inpainting models based on Stable Diffusion 1.X +[4] Something else''' if plural else '''What type of model is this?: +[1] A model based on Stable Diffusion 1.X +[2] A model based on Stable Diffusion 2.X +[3] An inpainting models based on Stable Diffusion 1.X +[4] Something else''' + print(prompt) + choice = input(f'Your choice: [{default}] ') + choice = choice.strip() or default + if config_file := choices.get(choice,None): + return Path('configs','stable-diffusion',config_file) -def optimize_model(model_name_or_path: str, gen, opt, completer): + # otherwise ask user to select + done = False + completer.complete_extensions(('.yaml','.yml')) + completer.set_line(str(Path(Globals.root,'configs/stable-diffusion/'))) + while not done: + config_path = input('Configuration file for this model (leave blank to abort): ').strip() + done = not config_path or os.path.exists(config_path) + return config_path + + +def optimize_model(model_name_or_path: Union[Path,str], gen, opt, completer): + model_name_or_path = model_name_or_path.replace('\\','/') # windows manager = gen.model_manager ckpt_path = None - original_config_file = None if model_name_or_path == gen.model_name: print("** Can't convert the active model. !switch to another model first. **") @@ -732,6 +800,9 @@ def optimize_model(model_name_or_path: str, gen, opt, completer): print(f'** {model_name_or_path} is not a legacy .ckpt weights file') return elif os.path.exists(model_name_or_path): + original_config_file = original_config_file or _ask_for_config_file(model_name_or_path, completer) + if not original_config_file: + return ckpt_path = Path(model_name_or_path) model_name, model_description = _get_model_name_and_desc( manager, @@ -739,12 +810,6 @@ def optimize_model(model_name_or_path: str, gen, opt, completer): ckpt_path.stem, f'Converted model {ckpt_path.stem}' ) - is_inpainting = _is_inpainting(model_name_or_path) - original_config_file = Path( - 'configs', - 'stable-diffusion', - 'v1-inpainting-inference.yaml' if is_inpainting else 'v1-inference.yaml' - ) else: print(f'** {model_name_or_path} is neither an existing model nor the path to a .ckpt file') return @@ -761,7 +826,7 @@ def optimize_model(model_name_or_path: str, gen, opt, completer): return vae = None - if input('Replace this model\'s VAE with "stabilityai/sd-vae-ft-mse"? [n] ').strip() in ('y','Y'): + if click.confirm('Replace this model\'s VAE with "stabilityai/sd-vae-ft-mse"?', default=False): vae = dict(repo_id='stabilityai/sd-vae-ft-mse') new_config = gen.model_manager.convert_and_import( @@ -777,11 +842,10 @@ def optimize_model(model_name_or_path: str, gen, opt, completer): return completer.update_models(gen.model_manager.list_models()) - if input(f'Load optimized model {model_name}? [y] ').strip() not in ('n','N'): + if click.confirm(f'Load optimized model {model_name}?', default=True): gen.set_model(model_name) - response = input(f'Delete the original .ckpt file at ({ckpt_path} ? [n] ') - if response.startswith(('y','Y')): + if click.confirm(f'Delete the original .ckpt file at {ckpt_path}?',default=False): ckpt_path.unlink(missing_ok=True) print(f'{ckpt_path} deleted') @@ -794,10 +858,10 @@ def del_config(model_name:str, gen, opt, completer): print(f"** Unknown model {model_name}") return - if input(f'Remove {model_name} from the list of models known to InvokeAI? [y] ').strip().startswith(('n','N')): + if not click.confirm(f'Remove {model_name} from the list of models known to InvokeAI?',default=True): return - delete_completely = input('Completely remove the model file or directory from disk? [n] ').startswith(('y','Y')) + delete_completely = click.confirm('Completely remove the model file or directory from disk?',default=False) gen.model_manager.del_model(model_name,delete_files=delete_completely) gen.model_manager.commit(opt.conf) print(f'** {model_name} deleted') @@ -826,7 +890,7 @@ def edit_model(model_name:str, gen, opt, completer): # this does the update manager.add_model(new_name, info, True) - if input('Make this the default model? [n] ').startswith(('y','Y')): + if click.confirm('Make this the default model?',default=False): manager.set_default_model(new_name) manager.commit(opt.conf) completer.update_models(manager.list_models()) @@ -1010,6 +1074,7 @@ def get_next_command(infile=None, model_name='no model') -> str: # command stri def invoke_ai_web_server_loop(gen: Generate, gfpgan, codeformer, esrgan): print('\n* --web was specified, starting web server...') from invokeai.backend import InvokeAIWebServer + # Change working directory to the stable-diffusion directory os.chdir( os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) @@ -1158,8 +1223,7 @@ def report_model_error(opt:Namespace, e:Exception): if yes_to_all: print('** Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE') else: - response = input('Do you want to run invokeai-configure script to select and/or reinstall models? [y] ') - if response.startswith(('n', 'N')): + if click.confirm('Do you want to run invokeai-configure script to select and/or reinstall models?', default=True): return print('invokeai-configure is launching....\n') diff --git a/ldm/invoke/model_manager.py b/ldm/invoke/model_manager.py index 27b5d064ef..3421d37717 100644 --- a/ldm/invoke/model_manager.py +++ b/ldm/invoke/model_manager.py @@ -34,8 +34,8 @@ from ldm.invoke.generator.diffusers_pipeline import \ StableDiffusionGeneratorPipeline from ldm.invoke.globals import (Globals, global_autoscan_dir, global_cache_dir, global_models_dir) -from ldm.util import (ask_user, download_with_progress_bar, - instantiate_from_config) +from ldm.util import (ask_user, download_with_resume, + url_attachment_name, instantiate_from_config) DEFAULT_MAX_MODELS = 2 VAE_TO_REPO_ID = { # hack, see note in convert_and_import() @@ -673,15 +673,18 @@ class ModelManager(object): path to the configuration file, then the new entry will be committed to the models.yaml file. """ + if str(weights).startswith(("http:", "https:")): + model_name = model_name or url_attachment_name(weights) + weights_path = self._resolve_path(weights, "models/ldm/stable-diffusion-v1") - config_path = self._resolve_path(config, "configs/stable-diffusion") + config_path = self._resolve_path(config, "configs/stable-diffusion") if weights_path is None or not weights_path.exists(): return False if config_path is None or not config_path.exists(): return False - model_name = model_name or Path(weights).stem + model_name = model_name or Path(weights).stem # note this gives ugly pathnames if used on a URL without a Content-Disposition header model_description = ( model_description or f"imported stable diffusion weights file {model_name}" ) @@ -971,16 +974,15 @@ class ModelManager(object): print("** Migration is done. Continuing...") def _resolve_path( - self, source: Union[str, Path], dest_directory: str + self, source: Union[str, Path], dest_directory: str ) -> Optional[Path]: resolved_path = None if str(source).startswith(("http:", "https:", "ftp:")): - basename = os.path.basename(source) - if not os.path.isabs(dest_directory): - dest_directory = os.path.join(Globals.root, dest_directory) - dest = os.path.join(dest_directory, basename) - if download_with_progress_bar(str(source), Path(dest)): - resolved_path = Path(dest) + dest_directory = Path(dest_directory) + if not dest_directory.is_absolute(): + dest_directory = Globals.root / dest_directory + dest_directory.mkdir(parents=True, exist_ok=True) + resolved_path = download_with_resume(str(source), dest_directory) else: if not os.path.isabs(source): source = os.path.join(Globals.root, source) diff --git a/ldm/util.py b/ldm/util.py index 447875537f..d6d2c9e170 100644 --- a/ldm/util.py +++ b/ldm/util.py @@ -1,20 +1,21 @@ import importlib import math import multiprocessing as mp +import os +import re from collections import abc from inspect import isfunction +from pathlib import Path from queue import Queue from threading import Thread -from urllib import request -from tqdm import tqdm -from pathlib import Path -from ldm.invoke.devices import torch_dtype import numpy as np +import requests import torch -import os -import traceback from PIL import Image, ImageDraw, ImageFont +from tqdm import tqdm + +from ldm.invoke.devices import torch_dtype def log_txt_as_img(wh, xc, size=10): @@ -23,18 +24,18 @@ def log_txt_as_img(wh, xc, size=10): b = len(xc) txts = list() for bi in range(b): - txt = Image.new('RGB', wh, color='white') + txt = Image.new("RGB", wh, color="white") draw = ImageDraw.Draw(txt) font = ImageFont.load_default() nc = int(40 * (wh[0] / 256)) - lines = '\n'.join( + lines = "\n".join( xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc) ) try: - draw.text((0, 0), lines, fill='black', font=font) + draw.text((0, 0), lines, fill="black", font=font) except UnicodeEncodeError: - print('Cant encode string for logging. Skipping.') + print("Cant encode string for logging. Skipping.") txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 txts.append(txt) @@ -77,25 +78,23 @@ def count_params(model, verbose=False): total_params = sum(p.numel() for p in model.parameters()) if verbose: print( - f' | {model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.' + f" | {model.__class__.__name__} has {total_params * 1.e-6:.2f} M params." ) return total_params def instantiate_from_config(config, **kwargs): - if not 'target' in config: - if config == '__is_first_stage__': + if not "target" in config: + if config == "__is_first_stage__": return None - elif config == '__is_unconditional__': + elif config == "__is_unconditional__": return None - raise KeyError('Expected key `target` to instantiate.') - return get_obj_from_str(config['target'])( - **config.get('params', dict()), **kwargs - ) + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict()), **kwargs) def get_obj_from_str(string, reload=False): - module, cls = string.rsplit('.', 1) + module, cls = string.rsplit(".", 1) if reload: module_imp = importlib.import_module(module) importlib.reload(module_imp) @@ -111,14 +110,14 @@ def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): else: res = func(data) Q.put([idx, res]) - Q.put('Done') + Q.put("Done") def parallel_data_prefetch( func: callable, data, n_proc, - target_data_type='ndarray', + target_data_type="ndarray", cpu_intensive=True, use_worker_id=False, ): @@ -126,21 +125,21 @@ def parallel_data_prefetch( # raise ValueError( # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." # ) - if isinstance(data, np.ndarray) and target_data_type == 'list': - raise ValueError('list expected but function got ndarray.') + if isinstance(data, np.ndarray) and target_data_type == "list": + raise ValueError("list expected but function got ndarray.") elif isinstance(data, abc.Iterable): if isinstance(data, dict): print( - f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' + 'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' ) data = list(data.values()) - if target_data_type == 'ndarray': + if target_data_type == "ndarray": data = np.asarray(data) else: data = list(data) else: raise TypeError( - f'The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}.' + f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." ) if cpu_intensive: @@ -150,7 +149,7 @@ def parallel_data_prefetch( Q = Queue(1000) proc = Thread # spawn processes - if target_data_type == 'ndarray': + if target_data_type == "ndarray": arguments = [ [func, Q, part, i, use_worker_id] for i, part in enumerate(np.array_split(data, n_proc)) @@ -173,7 +172,7 @@ def parallel_data_prefetch( processes += [p] # start processes - print(f'Start prefetching...') + print("Start prefetching...") import time start = time.time() @@ -186,13 +185,13 @@ def parallel_data_prefetch( while k < n_proc: # get result res = Q.get() - if res == 'Done': + if res == "Done": k += 1 else: gather_res[res[0]] = res[1] except Exception as e: - print('Exception: ', e) + print("Exception: ", e) for p in processes: p.terminate() @@ -200,15 +199,15 @@ def parallel_data_prefetch( finally: for p in processes: p.join() - print(f'Prefetching complete. [{time.time() - start} sec.]') + print(f"Prefetching complete. [{time.time() - start} sec.]") - if target_data_type == 'ndarray': + if target_data_type == "ndarray": if not isinstance(gather_res[0], np.ndarray): return np.concatenate([np.asarray(r) for r in gather_res], axis=0) # order outputs return np.concatenate(gather_res, axis=0) - elif target_data_type == 'list': + elif target_data_type == "list": out = [] for r in gather_res: out.extend(r) @@ -216,49 +215,79 @@ def parallel_data_prefetch( else: return gather_res -def rand_perlin_2d(shape, res, device, fade = lambda t: 6*t**5 - 15*t**4 + 10*t**3): + +def rand_perlin_2d( + shape, res, device, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3 +): delta = (res[0] / shape[0], res[1] / shape[1]) d = (shape[0] // res[0], shape[1] // res[1]) - grid = torch.stack(torch.meshgrid(torch.arange(0, res[0], delta[0]), torch.arange(0, res[1], delta[1]), indexing='ij'), dim = -1).to(device) % 1 + grid = ( + torch.stack( + torch.meshgrid( + torch.arange(0, res[0], delta[0]), + torch.arange(0, res[1], delta[1]), + indexing="ij", + ), + dim=-1, + ).to(device) + % 1 + ) - rand_val = torch.rand(res[0]+1, res[1]+1) + rand_val = torch.rand(res[0] + 1, res[1] + 1) - angles = 2*math.pi*rand_val - gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim = -1).to(device) + angles = 2 * math.pi * rand_val + gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1).to(device) - tile_grads = lambda slice1, slice2: gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]].repeat_interleave(d[0], 0).repeat_interleave(d[1], 1) + tile_grads = ( + lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]] + .repeat_interleave(d[0], 0) + .repeat_interleave(d[1], 1) + ) - dot = lambda grad, shift: (torch.stack((grid[:shape[0],:shape[1],0] + shift[0], grid[:shape[0],:shape[1], 1] + shift[1] ), dim = -1) * grad[:shape[0], :shape[1]]).sum(dim = -1) + dot = lambda grad, shift: ( + torch.stack( + ( + grid[: shape[0], : shape[1], 0] + shift[0], + grid[: shape[0], : shape[1], 1] + shift[1], + ), + dim=-1, + ) + * grad[: shape[0], : shape[1]] + ).sum(dim=-1) - n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]).to(device) + n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]).to(device) n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]).to(device) - n01 = dot(tile_grads([0, -1],[1, None]), [0, -1]).to(device) - n11 = dot(tile_grads([1, None], [1, None]), [-1,-1]).to(device) - t = fade(grid[:shape[0], :shape[1]]) - noise = math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]).to(device) + n01 = dot(tile_grads([0, -1], [1, None]), [0, -1]).to(device) + n11 = dot(tile_grads([1, None], [1, None]), [-1, -1]).to(device) + t = fade(grid[: shape[0], : shape[1]]) + noise = math.sqrt(2) * torch.lerp( + torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1] + ).to(device) return noise.to(dtype=torch_dtype(device)) + def ask_user(question: str, answers: list): from itertools import chain, repeat - user_prompt = f'\n>> {question} {answers}: ' - invalid_answer_msg = 'Invalid answer. Please try again.' - pose_question = chain([user_prompt], repeat('\n'.join([invalid_answer_msg, user_prompt]))) + + user_prompt = f"\n>> {question} {answers}: " + invalid_answer_msg = "Invalid answer. Please try again." + pose_question = chain( + [user_prompt], repeat("\n".join([invalid_answer_msg, user_prompt])) + ) user_answers = map(input, pose_question) valid_response = next(filter(answers.__contains__, user_answers)) return valid_response -def debug_image(debug_image, debug_text, debug_show=True, debug_result=False, debug_status=False ): +def debug_image( + debug_image, debug_text, debug_show=True, debug_result=False, debug_status=False +): if not debug_status: return image_copy = debug_image.copy().convert("RGBA") - ImageDraw.Draw(image_copy).text( - (5, 5), - debug_text, - (255, 0, 0) - ) + ImageDraw.Draw(image_copy).text((5, 5), debug_text, (255, 0, 0)) if debug_show: image_copy.show() @@ -266,31 +295,84 @@ def debug_image(debug_image, debug_text, debug_show=True, debug_result=False, de if debug_result: return image_copy -#------------------------------------- -class ProgressBar(): - def __init__(self,model_name='file'): - self.pbar = None - self.name = model_name - def __call__(self, block_num, block_size, total_size): - if not self.pbar: - self.pbar=tqdm(desc=self.name, - initial=0, - unit='iB', - unit_scale=True, - unit_divisor=1000, - total=total_size) - self.pbar.update(block_size) +# ------------------------------------- +def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path: + ''' + Download a model file. + :param url: https, http or ftp URL + :param dest: A Path object. If path exists and is a directory, then we try to derive the filename + from the URL's Content-Disposition header and copy the URL contents into + dest/filename + :param access_token: Access token to access this resource + ''' + resp = requests.get(url, stream=True) + total = int(resp.headers.get("content-length", 0)) + + if dest.is_dir(): + try: + file_name = re.search('filename="(.+)"', resp.headers.get("Content-Disposition")).group(1) + except: + file_name = os.path.basename(url) + dest = dest / file_name + else: + dest.parent.mkdir(parents=True, exist_ok=True) + + print(f'DEBUG: after many manipulations, dest={dest}') + + header = {"Authorization": f"Bearer {access_token}"} if access_token else {} + open_mode = "wb" + exist_size = 0 + + if dest.exists(): + exist_size = dest.stat().st_size + header["Range"] = f"bytes={exist_size}-" + open_mode = "ab" + + if ( + resp.status_code == 416 + ): # "range not satisfiable", which means nothing to return + print(f"* {dest}: complete file found. Skipping.") + return dest + elif resp.status_code != 200: + print(f"** An error occurred during downloading {dest}: {resp.reason}") + elif exist_size > 0: + print(f"* {dest}: partial file found. Resuming...") + else: + print(f"* {dest}: Downloading...") -def download_with_progress_bar(url:str, dest:Path)->bool: try: - if not dest.exists(): - dest.parent.mkdir(parents=True, exist_ok=True) - request.urlretrieve(url,dest,ProgressBar(dest.stem)) - return True - else: - return True - except OSError: - print(traceback.format_exc()) - return False + if total < 2000: + print(f"*** ERROR DOWNLOADING {url}: {resp.text}") + return None + with open(dest, open_mode) as file, tqdm( + desc=str(dest), + initial=exist_size, + total=total + exist_size, + unit="iB", + unit_scale=True, + unit_divisor=1000, + ) as bar: + for data in resp.iter_content(chunk_size=1024): + size = file.write(data) + bar.update(size) + except Exception as e: + print(f"An error occurred while downloading {dest}: {str(e)}") + return None + + return dest + + +def url_attachment_name(url: str) -> dict: + try: + resp = requests.get(url, stream=True) + match = re.search('filename="(.+)"', resp.headers.get("Content-Disposition")) + return match.group(1) + except: + return None + + +def download_with_progress_bar(url: str, dest: Path) -> bool: + result = download_with_resume(url, dest, access_token=None) + return result is not None