diff --git a/docs/installation/050_INSTALLING_MODELS.md b/docs/installation/050_INSTALLING_MODELS.md
index a610491798..5621075506 100644
--- a/docs/installation/050_INSTALLING_MODELS.md
+++ b/docs/installation/050_INSTALLING_MODELS.md
@@ -80,6 +80,13 @@ only `.safetensors` and `.ckpt` models, but they can be easily loaded
into InvokeAI and/or converted into optimized `diffusers` models. Be
aware that CIVITAI hosts many models that generate NSFW content.
+!!! note
+
+ InvokeAI 2.3.x does not support directly importing and
+ running Stable Diffusion version 2 checkpoint models. You may instead
+ convert them into `diffusers` models using the conversion methods
+ described below.
+
## Installation
There are multiple ways to install and manage models:
@@ -90,7 +97,7 @@ There are multiple ways to install and manage models:
models files.
3. The web interface (WebUI) has a GUI for importing and managing
-models.
+ models.
### Installation via `invokeai-configure`
@@ -106,7 +113,7 @@ confirm that the files are complete.
You can install a new model, including any of the community-supported ones, via
the command-line client's `!import_model` command.
-#### Installing `.ckpt` and `.safetensors` models
+#### Installing individual `.ckpt` and `.safetensors` models
If the model is already downloaded to your local disk, use
`!import_model /path/to/file.ckpt` to load it. For example:
@@ -131,15 +138,40 @@ invoke> !import_model https://example.org/sd_models/martians.safetensors
For this to work, the URL must not be password-protected. Otherwise
you will receive a 404 error.
-When you import a legacy model, the CLI will ask you a few questions
-about the model, including what size image it was trained on (usually
-512x512), what name and description you wish to use for it, what
-configuration file to use for it (usually the default
-`v1-inference.yaml`), whether you'd like to make this model the
-default at startup time, and whether you would like to install a
-custom VAE (variable autoencoder) file for the model. For recent
-models, the answer to the VAE question is usually "no," but it won't
-hurt to answer "yes".
+When you import a legacy model, the CLI will first ask you what type
+of model this is. You can indicate whether it is a model based on
+Stable Diffusion 1.x (1.4 or 1.5), one based on Stable Diffusion 2.x,
+or a 1.x inpainting model. Be careful to indicate the correct model
+type, or it will not load correctly. You can correct the model type
+after the fact using the `!edit_model` command.
+
+The system will then ask you a few other questions about the model,
+including what size image it was trained on (usually 512x512), what
+name and description you wish to use for it, and whether you would
+like to install a custom VAE (variable autoencoder) file for the
+model. For recent models, the answer to the VAE question is usually
+"no," but it won't hurt to answer "yes".
+
+After importing, the model will load. If this is successful, you will
+be asked if you want to keep the model loaded in memory to start
+generating immediately. You'll also be asked if you wish to make this
+the default model on startup. You can change this later using
+`!edit_model`.
+
+#### Importing a batch of `.ckpt` and `.safetensors` models from a directory
+
+You may also point `!import_model` to a directory containing a set of
+`.ckpt` or `.safetensors` files. They will be imported _en masse_.
+
+!!! example
+
+ ```console
+ invoke> !import_model C:/Users/fred/Downloads/civitai_models/
+ ```
+
+You will be given the option to import all models found in the
+directory, or select which ones to import. If there are subfolders
+within the directory, they will be searched for models to import.
#### Installing `diffusers` models
@@ -284,14 +316,18 @@ up a dialogue that lists the models you have already installed, and
allows you to load, delete or edit them:
To add a new model, click on **+ Add New** and select to either a
checkpoint/safetensors model, or a diffusers model:
In this example, we chose **Add Diffusers**. As shown in the figure
@@ -302,7 +338,9 @@ choose to enter a path to disk, the system will autocomplete for you
as you type:
Press **Add Model** at the bottom of the dialogue (scrolled out of
@@ -317,7 +355,9 @@ directory and press the "Search" icon. This will display the
subfolders, and allow you to choose which ones to import:
## Model Management Startup Options
@@ -342,9 +382,8 @@ invoke.sh --autoconvert /home/fred/stable-diffusion-checkpoints
And here is what the same argument looks like in `invokeai.init`:
-```
+```bash
--outdir="/home/fred/invokeai/outputs
--no-nsfw_checker
--autoconvert /home/fred/stable-diffusion-checkpoints
```
-
diff --git a/ldm/invoke/CLI.py b/ldm/invoke/CLI.py
index 32c6d816be..f72f6058aa 100644
--- a/ldm/invoke/CLI.py
+++ b/ldm/invoke/CLI.py
@@ -1,29 +1,31 @@
import os
import re
-import sys
import shlex
+import sys
import traceback
-
from argparse import Namespace
from pathlib import Path
-from typing import Optional, Union
+from typing import List, Optional, Union
+
+import click
if sys.platform == "darwin":
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
-from ldm.invoke.globals import Globals
+import pyparsing # type: ignore
+
+import ldm.invoke
from ldm.generate import Generate
-from ldm.invoke.prompt_parser import PromptParser
-from ldm.invoke.readline import get_completer, Completer
-from ldm.invoke.args import Args, metadata_dumps, metadata_from_png, dream_cmd_from_png
-from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata
+from ldm.invoke.args import (Args, dream_cmd_from_png, metadata_dumps,
+ metadata_from_png)
+from ldm.invoke.globals import Globals
from ldm.invoke.image_util import make_grid
from ldm.invoke.log import write_log
from ldm.invoke.model_manager import ModelManager
-
-import click # type: ignore
-import ldm.invoke
-import pyparsing # type: ignore
+from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata
+from ldm.invoke.prompt_parser import PromptParser
+from ldm.invoke.readline import Completer, get_completer
+from ldm.util import url_attachment_name
# global used in multiple functions (fix)
infile = None
@@ -66,11 +68,11 @@ def main():
print(f'>> InvokeAI runtime directory is "{Globals.root}"')
# loading here to avoid long delays on startup
- from ldm.generate import Generate
-
# these two lines prevent a horrible warning message from appearing
# when the frozen CLIP tokenizer is imported
import transformers # type: ignore
+
+ from ldm.generate import Generate
transformers.logging.set_verbosity_error()
import diffusers
diffusers.logging.set_verbosity_error()
@@ -574,10 +576,12 @@ def set_default_output_dir(opt:Args, completer:Completer):
def import_model(model_path: str, gen, opt, completer):
- '''
- model_path can be (1) a URL to a .ckpt file; (2) a local .ckpt file path; or
- (3) a huggingface repository id
- '''
+ """
+ model_path can be (1) a URL to a .ckpt file; (2) a local .ckpt file path;
+ (3) a huggingface repository id; or (4) a local directory containing a
+ diffusers model.
+ """
+ model.path = model_path.replace('\\','/') # windows
model_name = None
if model_path.startswith(('http:','https:','ftp:')):
@@ -592,12 +596,8 @@ def import_model(model_path: str, gen, opt, completer):
models = list(Path(model_path).rglob('*.ckpt')) + list(Path(model_path).rglob('*.safetensors'))
if models:
- # Only the last model name will be used below.
- for model in sorted(models):
-
- if click.confirm(f'Import {model.stem} ?', default=True):
- model_name = import_ckpt_model(model, gen, opt, completer)
- print()
+ models = import_checkpoint_list(models, gen, opt, completer)
+ model_name = models[0] if len(models) == 1 else None
else:
model_name = import_diffuser_model(Path(model_path), gen, opt, completer)
@@ -614,14 +614,53 @@ def import_model(model_path: str, gen, opt, completer):
print('** model failed to load. Discarding configuration entry')
gen.model_manager.del_model(model_name)
return
- if input('Make this the default model? [n] ').strip() in ('y','Y'):
+ if click.confirm('Make this the default model?', default=False):
gen.model_manager.set_default_model(model_name)
gen.model_manager.commit(opt.conf)
completer.update_models(gen.model_manager.list_models())
print(f'>> {model_name} successfully installed')
-def import_diffuser_model(path_or_repo: Union[Path, str], gen, _, completer) -> Optional[str]:
+def import_checkpoint_list(models: List[Path], gen, opt, completer)->List[str]:
+ '''
+ Does a mass import of all the checkpoint/safetensors on a path list
+ '''
+ model_names = list()
+ choice = input('** Directory of checkpoint/safetensors models detected. Install ll or elected models? [a] ') or 'a'
+ do_all = choice.startswith('a')
+ if do_all:
+ config_file = _ask_for_config_file(models[0], completer, plural=True)
+ manager = gen.model_manager
+ for model in sorted(models):
+ model_name = f'{model.stem}'
+ model_description = f'Imported model {model_name}'
+ if model_name in manager.model_names():
+ print(f'** {model_name} is already imported. Skipping.')
+ elif manager.import_ckpt_model(
+ model,
+ config = config_file,
+ model_name = model_name,
+ model_description = model_description,
+ commit_to_conf = opt.conf):
+ model_names.append(model_name)
+ print(f'>> Model {model_name} imported successfully')
+ else:
+ print(f'** Model {model} failed to import')
+ else:
+ for model in sorted(models):
+ if click.confirm(f'Import {model.stem} ?', default=True):
+ if model_name := import_ckpt_model(model, gen, opt, completer):
+ print(f'>> Model {model.stem} imported successfully')
+ model_names.append(model_name)
+ else:
+ printf('** Model {model} failed to import')
+ print()
+ return model_names
+
+def import_diffuser_model(
+ path_or_repo: Union[Path, str], gen, _, completer
+) -> Optional[str]:
+ path_or_repo = path_or_repo.replace('\\','/') # windows
manager = gen.model_manager
default_name = Path(path_or_repo).stem
default_description = f'Imported model {default_name}'
@@ -632,7 +671,7 @@ def import_diffuser_model(path_or_repo: Union[Path, str], gen, _, completer) ->
model_description=default_description
)
vae = None
- if input('Replace this model\'s VAE with "stabilityai/sd-vae-ft-mse"? [n] ').strip() in ('y','Y'):
+ if click.confirm('Replace this model\'s VAE with "stabilityai/sd-vae-ft-mse"?', default=False):
vae = dict(repo_id='stabilityai/sd-vae-ft-mse')
if not manager.import_diffuser_model(
@@ -644,27 +683,22 @@ def import_diffuser_model(path_or_repo: Union[Path, str], gen, _, completer) ->
return None
return model_name
-def import_ckpt_model(path_or_url: Union[Path, str], gen, opt, completer) -> Optional[str]:
+def import_ckpt_model(
+ path_or_url: Union[Path, str], gen, opt, completer
+) -> Optional[str]:
+ path_or_url = path_or_url.replace('\\','/')
manager = gen.model_manager
- default_name = Path(path_or_url).stem
- default_description = f'Imported model {default_name}'
+ is_a_url = str(path_or_url).startswith(('http:','https:'))
+ base_name = Path(url_attachment_name(path_or_url)).name if is_a_url else Path(path_or_url).name
+ default_name = Path(base_name).stem
+ default_description = f"Imported model {default_name}"
+
model_name, model_description = _get_model_name_and_desc(
manager,
completer,
model_name=default_name,
model_description=default_description
)
- config_file = None
- default = Path(Globals.root,'configs/stable-diffusion/v1-inpainting-inference.yaml') \
- if re.search('inpaint',default_name, flags=re.IGNORECASE) \
- else Path(Globals.root,'configs/stable-diffusion/v1-inference.yaml')
-
- completer.complete_extensions(('.yaml','.yml'))
- completer.set_line(str(default))
- done = False
- while not done:
- config_file = input('Configuration file for this model: ').strip()
- done = os.path.exists(config_file)
completer.complete_extensions(('.ckpt','.safetensors'))
vae = None
@@ -692,10 +726,15 @@ def import_ckpt_model(path_or_url: Union[Path, str], gen, opt, completer) -> Opt
def _verify_load(model_name:str, gen)->bool:
print('>> Verifying that new model loads...')
current_model = gen.model_name
- if not gen.model_manager.get_model(model_name):
+ try:
+ if not gen.model_manager.get_model(model_name):
+ return False
+ except Exception as e:
+ print(f'** model failed to load: {str(e)}')
+ print('** note that importing 2.X checkpoints is not supported. Please use !convert_model instead.')
return False
- do_switch = input('Keep model loaded? [y] ')
- if len(do_switch)==0 or do_switch[0] in ('y','Y'):
+
+ if click.confirm('Keep model loaded?', default=True):
gen.set_model(model_name)
else:
print('>> Restoring previous model')
@@ -708,16 +747,45 @@ def _get_model_name_and_desc(model_manager,completer,model_name:str='',model_des
model_description = input(f'Description for this model [{model_description}]: ').strip() or model_description
return model_name, model_description
-def _is_inpainting(model_name_or_path: str)->bool:
- if re.search('inpaint',model_name_or_path, flags=re.IGNORECASE):
- return not input('Is this an inpainting model? [y] ').startswith(('n','N'))
- else:
- return not input('Is this an inpainting model? [n] ').startswith(('y','Y'))
+def _ask_for_config_file(model_path: Union[str,Path], completer, plural: bool=False)->Path:
+ default = '1'
+ if re.search('inpaint',str(model_path),flags=re.IGNORECASE):
+ default = '3'
+ choices={
+ '1': 'v1-inference.yaml',
+ '2': 'v2-inference-v.yaml',
+ '3': 'v1-inpainting-inference.yaml',
+ }
+
+ prompt = '''What type of models are these?:
+[1] Models based on Stable Diffusion 1.X
+[2] Models based on Stable Diffusion 2.X
+[3] Inpainting models based on Stable Diffusion 1.X
+[4] Something else''' if plural else '''What type of model is this?:
+[1] A model based on Stable Diffusion 1.X
+[2] A model based on Stable Diffusion 2.X
+[3] An inpainting models based on Stable Diffusion 1.X
+[4] Something else'''
+ print(prompt)
+ choice = input(f'Your choice: [{default}] ')
+ choice = choice.strip() or default
+ if config_file := choices.get(choice,None):
+ return Path('configs','stable-diffusion',config_file)
-def optimize_model(model_name_or_path: str, gen, opt, completer):
+ # otherwise ask user to select
+ done = False
+ completer.complete_extensions(('.yaml','.yml'))
+ completer.set_line(str(Path(Globals.root,'configs/stable-diffusion/')))
+ while not done:
+ config_path = input('Configuration file for this model (leave blank to abort): ').strip()
+ done = not config_path or os.path.exists(config_path)
+ return config_path
+
+
+def optimize_model(model_name_or_path: Union[Path,str], gen, opt, completer):
+ model_name_or_path = model_name_or_path.replace('\\','/') # windows
manager = gen.model_manager
ckpt_path = None
- original_config_file = None
if model_name_or_path == gen.model_name:
print("** Can't convert the active model. !switch to another model first. **")
@@ -732,6 +800,9 @@ def optimize_model(model_name_or_path: str, gen, opt, completer):
print(f'** {model_name_or_path} is not a legacy .ckpt weights file')
return
elif os.path.exists(model_name_or_path):
+ original_config_file = original_config_file or _ask_for_config_file(model_name_or_path, completer)
+ if not original_config_file:
+ return
ckpt_path = Path(model_name_or_path)
model_name, model_description = _get_model_name_and_desc(
manager,
@@ -739,12 +810,6 @@ def optimize_model(model_name_or_path: str, gen, opt, completer):
ckpt_path.stem,
f'Converted model {ckpt_path.stem}'
)
- is_inpainting = _is_inpainting(model_name_or_path)
- original_config_file = Path(
- 'configs',
- 'stable-diffusion',
- 'v1-inpainting-inference.yaml' if is_inpainting else 'v1-inference.yaml'
- )
else:
print(f'** {model_name_or_path} is neither an existing model nor the path to a .ckpt file')
return
@@ -761,7 +826,7 @@ def optimize_model(model_name_or_path: str, gen, opt, completer):
return
vae = None
- if input('Replace this model\'s VAE with "stabilityai/sd-vae-ft-mse"? [n] ').strip() in ('y','Y'):
+ if click.confirm('Replace this model\'s VAE with "stabilityai/sd-vae-ft-mse"?', default=False):
vae = dict(repo_id='stabilityai/sd-vae-ft-mse')
new_config = gen.model_manager.convert_and_import(
@@ -777,11 +842,10 @@ def optimize_model(model_name_or_path: str, gen, opt, completer):
return
completer.update_models(gen.model_manager.list_models())
- if input(f'Load optimized model {model_name}? [y] ').strip() not in ('n','N'):
+ if click.confirm(f'Load optimized model {model_name}?', default=True):
gen.set_model(model_name)
- response = input(f'Delete the original .ckpt file at ({ckpt_path} ? [n] ')
- if response.startswith(('y','Y')):
+ if click.confirm(f'Delete the original .ckpt file at {ckpt_path}?',default=False):
ckpt_path.unlink(missing_ok=True)
print(f'{ckpt_path} deleted')
@@ -794,10 +858,10 @@ def del_config(model_name:str, gen, opt, completer):
print(f"** Unknown model {model_name}")
return
- if input(f'Remove {model_name} from the list of models known to InvokeAI? [y] ').strip().startswith(('n','N')):
+ if not click.confirm(f'Remove {model_name} from the list of models known to InvokeAI?',default=True):
return
- delete_completely = input('Completely remove the model file or directory from disk? [n] ').startswith(('y','Y'))
+ delete_completely = click.confirm('Completely remove the model file or directory from disk?',default=False)
gen.model_manager.del_model(model_name,delete_files=delete_completely)
gen.model_manager.commit(opt.conf)
print(f'** {model_name} deleted')
@@ -826,7 +890,7 @@ def edit_model(model_name:str, gen, opt, completer):
# this does the update
manager.add_model(new_name, info, True)
- if input('Make this the default model? [n] ').startswith(('y','Y')):
+ if click.confirm('Make this the default model?',default=False):
manager.set_default_model(new_name)
manager.commit(opt.conf)
completer.update_models(manager.list_models())
@@ -1010,6 +1074,7 @@ def get_next_command(infile=None, model_name='no model') -> str: # command stri
def invoke_ai_web_server_loop(gen: Generate, gfpgan, codeformer, esrgan):
print('\n* --web was specified, starting web server...')
from invokeai.backend import InvokeAIWebServer
+
# Change working directory to the stable-diffusion directory
os.chdir(
os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
@@ -1158,8 +1223,7 @@ def report_model_error(opt:Namespace, e:Exception):
if yes_to_all:
print('** Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE')
else:
- response = input('Do you want to run invokeai-configure script to select and/or reinstall models? [y] ')
- if response.startswith(('n', 'N')):
+ if click.confirm('Do you want to run invokeai-configure script to select and/or reinstall models?', default=True):
return
print('invokeai-configure is launching....\n')
diff --git a/ldm/invoke/model_manager.py b/ldm/invoke/model_manager.py
index 27b5d064ef..3421d37717 100644
--- a/ldm/invoke/model_manager.py
+++ b/ldm/invoke/model_manager.py
@@ -34,8 +34,8 @@ from ldm.invoke.generator.diffusers_pipeline import \
StableDiffusionGeneratorPipeline
from ldm.invoke.globals import (Globals, global_autoscan_dir, global_cache_dir,
global_models_dir)
-from ldm.util import (ask_user, download_with_progress_bar,
- instantiate_from_config)
+from ldm.util import (ask_user, download_with_resume,
+ url_attachment_name, instantiate_from_config)
DEFAULT_MAX_MODELS = 2
VAE_TO_REPO_ID = { # hack, see note in convert_and_import()
@@ -673,15 +673,18 @@ class ModelManager(object):
path to the configuration file, then the new entry will be committed to the
models.yaml file.
"""
+ if str(weights).startswith(("http:", "https:")):
+ model_name = model_name or url_attachment_name(weights)
+
weights_path = self._resolve_path(weights, "models/ldm/stable-diffusion-v1")
- config_path = self._resolve_path(config, "configs/stable-diffusion")
+ config_path = self._resolve_path(config, "configs/stable-diffusion")
if weights_path is None or not weights_path.exists():
return False
if config_path is None or not config_path.exists():
return False
- model_name = model_name or Path(weights).stem
+ model_name = model_name or Path(weights).stem # note this gives ugly pathnames if used on a URL without a Content-Disposition header
model_description = (
model_description or f"imported stable diffusion weights file {model_name}"
)
@@ -971,16 +974,15 @@ class ModelManager(object):
print("** Migration is done. Continuing...")
def _resolve_path(
- self, source: Union[str, Path], dest_directory: str
+ self, source: Union[str, Path], dest_directory: str
) -> Optional[Path]:
resolved_path = None
if str(source).startswith(("http:", "https:", "ftp:")):
- basename = os.path.basename(source)
- if not os.path.isabs(dest_directory):
- dest_directory = os.path.join(Globals.root, dest_directory)
- dest = os.path.join(dest_directory, basename)
- if download_with_progress_bar(str(source), Path(dest)):
- resolved_path = Path(dest)
+ dest_directory = Path(dest_directory)
+ if not dest_directory.is_absolute():
+ dest_directory = Globals.root / dest_directory
+ dest_directory.mkdir(parents=True, exist_ok=True)
+ resolved_path = download_with_resume(str(source), dest_directory)
else:
if not os.path.isabs(source):
source = os.path.join(Globals.root, source)
diff --git a/ldm/util.py b/ldm/util.py
index 447875537f..d6d2c9e170 100644
--- a/ldm/util.py
+++ b/ldm/util.py
@@ -1,20 +1,21 @@
import importlib
import math
import multiprocessing as mp
+import os
+import re
from collections import abc
from inspect import isfunction
+from pathlib import Path
from queue import Queue
from threading import Thread
-from urllib import request
-from tqdm import tqdm
-from pathlib import Path
-from ldm.invoke.devices import torch_dtype
import numpy as np
+import requests
import torch
-import os
-import traceback
from PIL import Image, ImageDraw, ImageFont
+from tqdm import tqdm
+
+from ldm.invoke.devices import torch_dtype
def log_txt_as_img(wh, xc, size=10):
@@ -23,18 +24,18 @@ def log_txt_as_img(wh, xc, size=10):
b = len(xc)
txts = list()
for bi in range(b):
- txt = Image.new('RGB', wh, color='white')
+ txt = Image.new("RGB", wh, color="white")
draw = ImageDraw.Draw(txt)
font = ImageFont.load_default()
nc = int(40 * (wh[0] / 256))
- lines = '\n'.join(
+ lines = "\n".join(
xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc)
)
try:
- draw.text((0, 0), lines, fill='black', font=font)
+ draw.text((0, 0), lines, fill="black", font=font)
except UnicodeEncodeError:
- print('Cant encode string for logging. Skipping.')
+ print("Cant encode string for logging. Skipping.")
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
txts.append(txt)
@@ -77,25 +78,23 @@ def count_params(model, verbose=False):
total_params = sum(p.numel() for p in model.parameters())
if verbose:
print(
- f' | {model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.'
+ f" | {model.__class__.__name__} has {total_params * 1.e-6:.2f} M params."
)
return total_params
def instantiate_from_config(config, **kwargs):
- if not 'target' in config:
- if config == '__is_first_stage__':
+ if not "target" in config:
+ if config == "__is_first_stage__":
return None
- elif config == '__is_unconditional__':
+ elif config == "__is_unconditional__":
return None
- raise KeyError('Expected key `target` to instantiate.')
- return get_obj_from_str(config['target'])(
- **config.get('params', dict()), **kwargs
- )
+ raise KeyError("Expected key `target` to instantiate.")
+ return get_obj_from_str(config["target"])(**config.get("params", dict()), **kwargs)
def get_obj_from_str(string, reload=False):
- module, cls = string.rsplit('.', 1)
+ module, cls = string.rsplit(".", 1)
if reload:
module_imp = importlib.import_module(module)
importlib.reload(module_imp)
@@ -111,14 +110,14 @@ def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
else:
res = func(data)
Q.put([idx, res])
- Q.put('Done')
+ Q.put("Done")
def parallel_data_prefetch(
func: callable,
data,
n_proc,
- target_data_type='ndarray',
+ target_data_type="ndarray",
cpu_intensive=True,
use_worker_id=False,
):
@@ -126,21 +125,21 @@ def parallel_data_prefetch(
# raise ValueError(
# "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
# )
- if isinstance(data, np.ndarray) and target_data_type == 'list':
- raise ValueError('list expected but function got ndarray.')
+ if isinstance(data, np.ndarray) and target_data_type == "list":
+ raise ValueError("list expected but function got ndarray.")
elif isinstance(data, abc.Iterable):
if isinstance(data, dict):
print(
- f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
+ 'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
)
data = list(data.values())
- if target_data_type == 'ndarray':
+ if target_data_type == "ndarray":
data = np.asarray(data)
else:
data = list(data)
else:
raise TypeError(
- f'The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}.'
+ f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
)
if cpu_intensive:
@@ -150,7 +149,7 @@ def parallel_data_prefetch(
Q = Queue(1000)
proc = Thread
# spawn processes
- if target_data_type == 'ndarray':
+ if target_data_type == "ndarray":
arguments = [
[func, Q, part, i, use_worker_id]
for i, part in enumerate(np.array_split(data, n_proc))
@@ -173,7 +172,7 @@ def parallel_data_prefetch(
processes += [p]
# start processes
- print(f'Start prefetching...')
+ print("Start prefetching...")
import time
start = time.time()
@@ -186,13 +185,13 @@ def parallel_data_prefetch(
while k < n_proc:
# get result
res = Q.get()
- if res == 'Done':
+ if res == "Done":
k += 1
else:
gather_res[res[0]] = res[1]
except Exception as e:
- print('Exception: ', e)
+ print("Exception: ", e)
for p in processes:
p.terminate()
@@ -200,15 +199,15 @@ def parallel_data_prefetch(
finally:
for p in processes:
p.join()
- print(f'Prefetching complete. [{time.time() - start} sec.]')
+ print(f"Prefetching complete. [{time.time() - start} sec.]")
- if target_data_type == 'ndarray':
+ if target_data_type == "ndarray":
if not isinstance(gather_res[0], np.ndarray):
return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
# order outputs
return np.concatenate(gather_res, axis=0)
- elif target_data_type == 'list':
+ elif target_data_type == "list":
out = []
for r in gather_res:
out.extend(r)
@@ -216,49 +215,79 @@ def parallel_data_prefetch(
else:
return gather_res
-def rand_perlin_2d(shape, res, device, fade = lambda t: 6*t**5 - 15*t**4 + 10*t**3):
+
+def rand_perlin_2d(
+ shape, res, device, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3
+):
delta = (res[0] / shape[0], res[1] / shape[1])
d = (shape[0] // res[0], shape[1] // res[1])
- grid = torch.stack(torch.meshgrid(torch.arange(0, res[0], delta[0]), torch.arange(0, res[1], delta[1]), indexing='ij'), dim = -1).to(device) % 1
+ grid = (
+ torch.stack(
+ torch.meshgrid(
+ torch.arange(0, res[0], delta[0]),
+ torch.arange(0, res[1], delta[1]),
+ indexing="ij",
+ ),
+ dim=-1,
+ ).to(device)
+ % 1
+ )
- rand_val = torch.rand(res[0]+1, res[1]+1)
+ rand_val = torch.rand(res[0] + 1, res[1] + 1)
- angles = 2*math.pi*rand_val
- gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim = -1).to(device)
+ angles = 2 * math.pi * rand_val
+ gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1).to(device)
- tile_grads = lambda slice1, slice2: gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]].repeat_interleave(d[0], 0).repeat_interleave(d[1], 1)
+ tile_grads = (
+ lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
+ .repeat_interleave(d[0], 0)
+ .repeat_interleave(d[1], 1)
+ )
- dot = lambda grad, shift: (torch.stack((grid[:shape[0],:shape[1],0] + shift[0], grid[:shape[0],:shape[1], 1] + shift[1] ), dim = -1) * grad[:shape[0], :shape[1]]).sum(dim = -1)
+ dot = lambda grad, shift: (
+ torch.stack(
+ (
+ grid[: shape[0], : shape[1], 0] + shift[0],
+ grid[: shape[0], : shape[1], 1] + shift[1],
+ ),
+ dim=-1,
+ )
+ * grad[: shape[0], : shape[1]]
+ ).sum(dim=-1)
- n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]).to(device)
+ n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]).to(device)
n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]).to(device)
- n01 = dot(tile_grads([0, -1],[1, None]), [0, -1]).to(device)
- n11 = dot(tile_grads([1, None], [1, None]), [-1,-1]).to(device)
- t = fade(grid[:shape[0], :shape[1]])
- noise = math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]).to(device)
+ n01 = dot(tile_grads([0, -1], [1, None]), [0, -1]).to(device)
+ n11 = dot(tile_grads([1, None], [1, None]), [-1, -1]).to(device)
+ t = fade(grid[: shape[0], : shape[1]])
+ noise = math.sqrt(2) * torch.lerp(
+ torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]
+ ).to(device)
return noise.to(dtype=torch_dtype(device))
+
def ask_user(question: str, answers: list):
from itertools import chain, repeat
- user_prompt = f'\n>> {question} {answers}: '
- invalid_answer_msg = 'Invalid answer. Please try again.'
- pose_question = chain([user_prompt], repeat('\n'.join([invalid_answer_msg, user_prompt])))
+
+ user_prompt = f"\n>> {question} {answers}: "
+ invalid_answer_msg = "Invalid answer. Please try again."
+ pose_question = chain(
+ [user_prompt], repeat("\n".join([invalid_answer_msg, user_prompt]))
+ )
user_answers = map(input, pose_question)
valid_response = next(filter(answers.__contains__, user_answers))
return valid_response
-def debug_image(debug_image, debug_text, debug_show=True, debug_result=False, debug_status=False ):
+def debug_image(
+ debug_image, debug_text, debug_show=True, debug_result=False, debug_status=False
+):
if not debug_status:
return
image_copy = debug_image.copy().convert("RGBA")
- ImageDraw.Draw(image_copy).text(
- (5, 5),
- debug_text,
- (255, 0, 0)
- )
+ ImageDraw.Draw(image_copy).text((5, 5), debug_text, (255, 0, 0))
if debug_show:
image_copy.show()
@@ -266,31 +295,84 @@ def debug_image(debug_image, debug_text, debug_show=True, debug_result=False, de
if debug_result:
return image_copy
-#-------------------------------------
-class ProgressBar():
- def __init__(self,model_name='file'):
- self.pbar = None
- self.name = model_name
- def __call__(self, block_num, block_size, total_size):
- if not self.pbar:
- self.pbar=tqdm(desc=self.name,
- initial=0,
- unit='iB',
- unit_scale=True,
- unit_divisor=1000,
- total=total_size)
- self.pbar.update(block_size)
+# -------------------------------------
+def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path:
+ '''
+ Download a model file.
+ :param url: https, http or ftp URL
+ :param dest: A Path object. If path exists and is a directory, then we try to derive the filename
+ from the URL's Content-Disposition header and copy the URL contents into
+ dest/filename
+ :param access_token: Access token to access this resource
+ '''
+ resp = requests.get(url, stream=True)
+ total = int(resp.headers.get("content-length", 0))
+
+ if dest.is_dir():
+ try:
+ file_name = re.search('filename="(.+)"', resp.headers.get("Content-Disposition")).group(1)
+ except:
+ file_name = os.path.basename(url)
+ dest = dest / file_name
+ else:
+ dest.parent.mkdir(parents=True, exist_ok=True)
+
+ print(f'DEBUG: after many manipulations, dest={dest}')
+
+ header = {"Authorization": f"Bearer {access_token}"} if access_token else {}
+ open_mode = "wb"
+ exist_size = 0
+
+ if dest.exists():
+ exist_size = dest.stat().st_size
+ header["Range"] = f"bytes={exist_size}-"
+ open_mode = "ab"
+
+ if (
+ resp.status_code == 416
+ ): # "range not satisfiable", which means nothing to return
+ print(f"* {dest}: complete file found. Skipping.")
+ return dest
+ elif resp.status_code != 200:
+ print(f"** An error occurred during downloading {dest}: {resp.reason}")
+ elif exist_size > 0:
+ print(f"* {dest}: partial file found. Resuming...")
+ else:
+ print(f"* {dest}: Downloading...")
-def download_with_progress_bar(url:str, dest:Path)->bool:
try:
- if not dest.exists():
- dest.parent.mkdir(parents=True, exist_ok=True)
- request.urlretrieve(url,dest,ProgressBar(dest.stem))
- return True
- else:
- return True
- except OSError:
- print(traceback.format_exc())
- return False
+ if total < 2000:
+ print(f"*** ERROR DOWNLOADING {url}: {resp.text}")
+ return None
+ with open(dest, open_mode) as file, tqdm(
+ desc=str(dest),
+ initial=exist_size,
+ total=total + exist_size,
+ unit="iB",
+ unit_scale=True,
+ unit_divisor=1000,
+ ) as bar:
+ for data in resp.iter_content(chunk_size=1024):
+ size = file.write(data)
+ bar.update(size)
+ except Exception as e:
+ print(f"An error occurred while downloading {dest}: {str(e)}")
+ return None
+
+ return dest
+
+
+def url_attachment_name(url: str) -> dict:
+ try:
+ resp = requests.get(url, stream=True)
+ match = re.search('filename="(.+)"', resp.headers.get("Content-Disposition"))
+ return match.group(1)
+ except:
+ return None
+
+
+def download_with_progress_bar(url: str, dest: Path) -> bool:
+ result = download_with_resume(url, dest, access_token=None)
+ return result is not None