Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
fix broken !import_model downloads
1. Now works with sites that produce lots of redirects, such as CIVITAI.
2. Derive name of destination model file from HTTP Content-Disposition header, if present.
3. Swap \\ for / in file paths provided by users, to hopefully fix issues with Windows.
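For context on items (1) and (2): requests follows HTTP redirects by default, so the response headers inspected afterwards describe the file that is finally served (the CIVITAI case), and the destination filename can be read out of its Content-Disposition header. Below is a minimal standalone sketch of that idea, assuming a hypothetical helper name guess_download_name; the helper actually added by this commit is url_attachment_name() in ldm/util.py, shown in the diff further down:

import os
import re
from urllib.parse import urlparse

import requests


def guess_download_name(url: str) -> str:
    """Best-effort filename for a download URL (illustrative sketch only)."""
    # requests follows redirects by default, so resp.headers and resp.url
    # describe the final target, not the original landing URL.
    resp = requests.get(url, stream=True, allow_redirects=True)
    cd = resp.headers.get("Content-Disposition", "")
    match = re.search(r'filename="?([^";]+)"?', cd)
    if match:
        return match.group(1)
    # No Content-Disposition header: fall back to the last URL path component.
    return os.path.basename(urlparse(resp.url).path)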
This commit is contained in:
parent
15a9412255
commit
d38e7170fe
@@ -1,29 +1,31 @@
-import click
 import os
 import re
-import sys
 import shlex
+import sys
 import traceback
-
 from argparse import Namespace
 from pathlib import Path
-from typing import Optional, Union, List
+from typing import List, Optional, Union
+
+import click
 
 if sys.platform == "darwin":
     os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
 
-from ldm.invoke.globals import Globals
+import pyparsing  # type: ignore
+
+import ldm.invoke
 from ldm.generate import Generate
-from ldm.invoke.prompt_parser import PromptParser
-from ldm.invoke.readline import get_completer, Completer
-from ldm.invoke.args import Args, metadata_dumps, metadata_from_png, dream_cmd_from_png
-from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata
+from ldm.invoke.args import (Args, dream_cmd_from_png, metadata_dumps,
+                             metadata_from_png)
+from ldm.invoke.globals import Globals
 from ldm.invoke.image_util import make_grid
 from ldm.invoke.log import write_log
 from ldm.invoke.model_manager import ModelManager
-
-import ldm.invoke
-import pyparsing  # type: ignore
+from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata
+from ldm.invoke.prompt_parser import PromptParser
+from ldm.invoke.readline import Completer, get_completer
+from ldm.util import url_attachment_name
 
 # global used in multiple functions (fix)
 infile = None
@@ -66,11 +68,11 @@ def main():
     print(f'>> InvokeAI runtime directory is "{Globals.root}"')
 
     # loading here to avoid long delays on startup
-    from ldm.generate import Generate
-
     # these two lines prevent a horrible warning message from appearing
     # when the frozen CLIP tokenizer is imported
     import transformers  # type: ignore
+
+    from ldm.generate import Generate
     transformers.logging.set_verbosity_error()
     import diffusers
     diffusers.logging.set_verbosity_error()
@@ -574,10 +576,12 @@ def set_default_output_dir(opt:Args, completer:Completer):
 
 
 def import_model(model_path: str, gen, opt, completer):
-    '''
-    model_path can be (1) a URL to a .ckpt file; (2) a local .ckpt file path; or
-    (3) a huggingface repository id
-    '''
+    """
+    model_path can be (1) a URL to a .ckpt file; (2) a local .ckpt file path;
+    (3) a huggingface repository id; or (4) a local directory containing a
+    diffusers model.
+    """
+    model.path = model_path.replace('\\','/') # windows
     model_name = None
 
     if model_path.startswith(('http:','https:','ftp:')):
@@ -653,7 +657,10 @@ def import_checkpoint_list(models: List[Path], gen, opt, completer)->List[str]:
     print()
     return model_names
 
-def import_diffuser_model(path_or_repo: Union[Path, str], gen, _, completer) -> Optional[str]:
+def import_diffuser_model(
+    path_or_repo: Union[Path, str], gen, _, completer
+) -> Optional[str]:
+    path_or_repo = path_or_repo.replace('\\','/') # windows
     manager = gen.model_manager
     default_name = Path(path_or_repo).stem
     default_description = f'Imported model {default_name}'
@@ -676,17 +683,16 @@ def import_diffuser_model(path_or_repo: Union[Path, str], gen, _, completer) ->
         return None
     return model_name
 
-def import_ckpt_model(path_or_url: Union[Path, str], gen, opt, completer) -> Optional[str]:
+def import_ckpt_model(
+    path_or_url: Union[Path, str], gen, opt, completer
+) -> Optional[str]:
+    path_or_url = path_or_url.replace('\\','/')
     manager = gen.model_manager
-
-    if not (config_file := _ask_for_config_file(path_or_url, completer)):
-        return
-    if config_file.stem == 'v2-inference-v':
-        print('** InvokeAI cannot run SD 2.X checkpoints directly. Model will be converted into diffusers format')
-        return optimize_model(path_or_url, gen, opt, completer, config_file)
-
-    default_name = Path(path_or_url).stem
-    default_description = f'Imported model {default_name}'
+    is_a_url = str(path_or_url).startswith(('http:','https:'))
+    base_name = Path(url_attachment_name(path_or_url)).name if is_a_url else Path(path_or_url).name
+    default_name = Path(base_name).stem
+    default_description = f"Imported model {default_name}"
+
     model_name, model_description = _get_model_name_and_desc(
         manager,
         completer,
@@ -776,7 +782,8 @@ def _ask_for_config_file(model_path: Union[str,Path], completer, plural: bool=Fa
     return config_path
 
 
-def optimize_model(model_name_or_path:str, gen, opt, completer, original_config_file: Path=None):
+def optimize_model(model_name_or_path: Union[Path,str], gen, opt, completer):
+    model_name_or_path = model_name_or_path.replace('\\','/') # windows
     manager = gen.model_manager
     ckpt_path = None
 
@@ -1067,6 +1074,7 @@ def get_next_command(infile=None, model_name='no model') -> str: # command stri
 def invoke_ai_web_server_loop(gen: Generate, gfpgan, codeformer, esrgan):
     print('\n* --web was specified, starting web server...')
     from invokeai.backend import InvokeAIWebServer
+
     # Change working directory to the stable-diffusion directory
     os.chdir(
         os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
@@ -36,8 +36,8 @@ from ldm.invoke.generator.diffusers_pipeline import \
     StableDiffusionGeneratorPipeline
 from ldm.invoke.globals import (Globals, global_autoscan_dir, global_cache_dir,
                                 global_models_dir)
-from ldm.util import (ask_user, download_with_progress_bar,
-                      instantiate_from_config)
+from ldm.util import (ask_user, download_with_resume,
+                      url_attachment_name, instantiate_from_config)
 
 DEFAULT_MAX_MODELS = 2
 VAE_TO_REPO_ID = { # hack, see note in convert_and_import()
@@ -670,15 +670,18 @@ class ModelManager(object):
         path to the configuration file, then the new entry will be committed to the
         models.yaml file.
         """
+        if str(weights).startswith(("http:", "https:")):
+            model_name = model_name or url_attachment_name(weights)
+
         weights_path = self._resolve_path(weights, "models/ldm/stable-diffusion-v1")
         config_path = self._resolve_path(config, "configs/stable-diffusion")
 
         if weights_path is None or not weights_path.exists():
             return False
         if config_path is None or not config_path.exists():
             return False
 
-        model_name = model_name or Path(weights).stem
+        model_name = model_name or Path(weights).stem # note this gives ugly pathnames if used on a URL without a Content-Disposition header
         model_description = (
             model_description or f"imported stable diffusion weights file {model_name}"
         )
@@ -748,7 +751,6 @@ class ModelManager(object):
         into models.yaml.
         """
         new_config = None
-        import transformers
 
         from ldm.invoke.ckpt_to_diffuser import convert_ckpt_to_diffuser
 
@@ -967,16 +969,15 @@ class ModelManager(object):
             print("** Migration is done. Continuing...")
 
     def _resolve_path(
         self, source: Union[str, Path], dest_directory: str
     ) -> Optional[Path]:
         resolved_path = None
         if str(source).startswith(("http:", "https:", "ftp:")):
-            basename = os.path.basename(source)
-            if not os.path.isabs(dest_directory):
-                dest_directory = os.path.join(Globals.root, dest_directory)
-            dest = os.path.join(dest_directory, basename)
-            if download_with_progress_bar(str(source), Path(dest)):
-                resolved_path = Path(dest)
+            dest_directory = Path(dest_directory)
+            if not dest_directory.is_absolute():
+                dest_directory = Globals.root / dest_directory
+            dest_directory.mkdir(parents=True, exist_ok=True)
+            resolved_path = download_with_resume(str(source), dest_directory)
         else:
             if not os.path.isabs(source):
                 source = os.path.join(Globals.root, source)
ldm/util.py (238)
@@ -1,20 +1,21 @@
 import importlib
 import math
 import multiprocessing as mp
+import os
+import re
 from collections import abc
 from inspect import isfunction
+from pathlib import Path
 from queue import Queue
 from threading import Thread
-from urllib import request
-from tqdm import tqdm
-from pathlib import Path
-from ldm.invoke.devices import torch_dtype
 
 import numpy as np
+import requests
 import torch
-import os
-import traceback
 from PIL import Image, ImageDraw, ImageFont
+from tqdm import tqdm
+
+from ldm.invoke.devices import torch_dtype
 
 
 def log_txt_as_img(wh, xc, size=10):
@@ -23,18 +24,18 @@ def log_txt_as_img(wh, xc, size=10):
     b = len(xc)
     txts = list()
     for bi in range(b):
-        txt = Image.new('RGB', wh, color='white')
+        txt = Image.new("RGB", wh, color="white")
         draw = ImageDraw.Draw(txt)
         font = ImageFont.load_default()
         nc = int(40 * (wh[0] / 256))
-        lines = '\n'.join(
+        lines = "\n".join(
             xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc)
         )
 
         try:
-            draw.text((0, 0), lines, fill='black', font=font)
+            draw.text((0, 0), lines, fill="black", font=font)
         except UnicodeEncodeError:
-            print('Cant encode string for logging. Skipping.')
+            print("Cant encode string for logging. Skipping.")
 
         txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
         txts.append(txt)
@@ -77,25 +78,23 @@ def count_params(model, verbose=False):
     total_params = sum(p.numel() for p in model.parameters())
     if verbose:
         print(
-            f' | {model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.'
+            f" | {model.__class__.__name__} has {total_params * 1.e-6:.2f} M params."
         )
     return total_params
 
 
 def instantiate_from_config(config, **kwargs):
-    if not 'target' in config:
-        if config == '__is_first_stage__':
+    if not "target" in config:
+        if config == "__is_first_stage__":
             return None
-        elif config == '__is_unconditional__':
+        elif config == "__is_unconditional__":
             return None
-        raise KeyError('Expected key `target` to instantiate.')
-    return get_obj_from_str(config['target'])(
-        **config.get('params', dict()), **kwargs
-    )
+        raise KeyError("Expected key `target` to instantiate.")
+    return get_obj_from_str(config["target"])(**config.get("params", dict()), **kwargs)
 
 
 def get_obj_from_str(string, reload=False):
-    module, cls = string.rsplit('.', 1)
+    module, cls = string.rsplit(".", 1)
     if reload:
         module_imp = importlib.import_module(module)
         importlib.reload(module_imp)
|
|||||||
else:
|
else:
|
||||||
res = func(data)
|
res = func(data)
|
||||||
Q.put([idx, res])
|
Q.put([idx, res])
|
||||||
Q.put('Done')
|
Q.put("Done")
|
||||||
|
|
||||||
|
|
||||||
def parallel_data_prefetch(
|
def parallel_data_prefetch(
|
||||||
func: callable,
|
func: callable,
|
||||||
data,
|
data,
|
||||||
n_proc,
|
n_proc,
|
||||||
target_data_type='ndarray',
|
target_data_type="ndarray",
|
||||||
cpu_intensive=True,
|
cpu_intensive=True,
|
||||||
use_worker_id=False,
|
use_worker_id=False,
|
||||||
):
|
):
|
||||||
@@ -126,21 +125,21 @@ def parallel_data_prefetch(
     # raise ValueError(
     #     "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
     # )
-    if isinstance(data, np.ndarray) and target_data_type == 'list':
-        raise ValueError('list expected but function got ndarray.')
+    if isinstance(data, np.ndarray) and target_data_type == "list":
+        raise ValueError("list expected but function got ndarray.")
     elif isinstance(data, abc.Iterable):
         if isinstance(data, dict):
             print(
-                f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
+                'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
             )
             data = list(data.values())
-        if target_data_type == 'ndarray':
+        if target_data_type == "ndarray":
             data = np.asarray(data)
         else:
             data = list(data)
     else:
         raise TypeError(
-            f'The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}.'
+            f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
         )
 
     if cpu_intensive:
|
|||||||
Q = Queue(1000)
|
Q = Queue(1000)
|
||||||
proc = Thread
|
proc = Thread
|
||||||
# spawn processes
|
# spawn processes
|
||||||
if target_data_type == 'ndarray':
|
if target_data_type == "ndarray":
|
||||||
arguments = [
|
arguments = [
|
||||||
[func, Q, part, i, use_worker_id]
|
[func, Q, part, i, use_worker_id]
|
||||||
for i, part in enumerate(np.array_split(data, n_proc))
|
for i, part in enumerate(np.array_split(data, n_proc))
|
||||||
@ -173,7 +172,7 @@ def parallel_data_prefetch(
|
|||||||
processes += [p]
|
processes += [p]
|
||||||
|
|
||||||
# start processes
|
# start processes
|
||||||
print(f'Start prefetching...')
|
print("Start prefetching...")
|
||||||
import time
|
import time
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
@ -186,13 +185,13 @@ def parallel_data_prefetch(
|
|||||||
while k < n_proc:
|
while k < n_proc:
|
||||||
# get result
|
# get result
|
||||||
res = Q.get()
|
res = Q.get()
|
||||||
if res == 'Done':
|
if res == "Done":
|
||||||
k += 1
|
k += 1
|
||||||
else:
|
else:
|
||||||
gather_res[res[0]] = res[1]
|
gather_res[res[0]] = res[1]
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print('Exception: ', e)
|
print("Exception: ", e)
|
||||||
for p in processes:
|
for p in processes:
|
||||||
p.terminate()
|
p.terminate()
|
||||||
|
|
||||||
@ -200,15 +199,15 @@ def parallel_data_prefetch(
|
|||||||
finally:
|
finally:
|
||||||
for p in processes:
|
for p in processes:
|
||||||
p.join()
|
p.join()
|
||||||
print(f'Prefetching complete. [{time.time() - start} sec.]')
|
print(f"Prefetching complete. [{time.time() - start} sec.]")
|
||||||
|
|
||||||
if target_data_type == 'ndarray':
|
if target_data_type == "ndarray":
|
||||||
if not isinstance(gather_res[0], np.ndarray):
|
if not isinstance(gather_res[0], np.ndarray):
|
||||||
return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
|
return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
|
||||||
|
|
||||||
# order outputs
|
# order outputs
|
||||||
return np.concatenate(gather_res, axis=0)
|
return np.concatenate(gather_res, axis=0)
|
||||||
elif target_data_type == 'list':
|
elif target_data_type == "list":
|
||||||
out = []
|
out = []
|
||||||
for r in gather_res:
|
for r in gather_res:
|
||||||
out.extend(r)
|
out.extend(r)
|
||||||
@ -216,49 +215,79 @@ def parallel_data_prefetch(
|
|||||||
else:
|
else:
|
||||||
return gather_res
|
return gather_res
|
||||||
|
|
||||||
def rand_perlin_2d(shape, res, device, fade = lambda t: 6*t**5 - 15*t**4 + 10*t**3):
|
|
||||||
|
def rand_perlin_2d(
|
||||||
|
shape, res, device, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3
|
||||||
|
):
|
||||||
delta = (res[0] / shape[0], res[1] / shape[1])
|
delta = (res[0] / shape[0], res[1] / shape[1])
|
||||||
d = (shape[0] // res[0], shape[1] // res[1])
|
d = (shape[0] // res[0], shape[1] // res[1])
|
||||||
|
|
||||||
grid = torch.stack(torch.meshgrid(torch.arange(0, res[0], delta[0]), torch.arange(0, res[1], delta[1]), indexing='ij'), dim = -1).to(device) % 1
|
grid = (
|
||||||
|
torch.stack(
|
||||||
|
torch.meshgrid(
|
||||||
|
torch.arange(0, res[0], delta[0]),
|
||||||
|
torch.arange(0, res[1], delta[1]),
|
||||||
|
indexing="ij",
|
||||||
|
),
|
||||||
|
dim=-1,
|
||||||
|
).to(device)
|
||||||
|
% 1
|
||||||
|
)
|
||||||
|
|
||||||
rand_val = torch.rand(res[0]+1, res[1]+1)
|
rand_val = torch.rand(res[0] + 1, res[1] + 1)
|
||||||
|
|
||||||
angles = 2*math.pi*rand_val
|
angles = 2 * math.pi * rand_val
|
||||||
gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim = -1).to(device)
|
gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1).to(device)
|
||||||
|
|
||||||
tile_grads = lambda slice1, slice2: gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]].repeat_interleave(d[0], 0).repeat_interleave(d[1], 1)
|
tile_grads = (
|
||||||
|
lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
|
||||||
|
.repeat_interleave(d[0], 0)
|
||||||
|
.repeat_interleave(d[1], 1)
|
||||||
|
)
|
||||||
|
|
||||||
dot = lambda grad, shift: (torch.stack((grid[:shape[0],:shape[1],0] + shift[0], grid[:shape[0],:shape[1], 1] + shift[1] ), dim = -1) * grad[:shape[0], :shape[1]]).sum(dim = -1)
|
dot = lambda grad, shift: (
|
||||||
|
torch.stack(
|
||||||
|
(
|
||||||
|
grid[: shape[0], : shape[1], 0] + shift[0],
|
||||||
|
grid[: shape[0], : shape[1], 1] + shift[1],
|
||||||
|
),
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
* grad[: shape[0], : shape[1]]
|
||||||
|
).sum(dim=-1)
|
||||||
|
|
||||||
n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]).to(device)
|
n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]).to(device)
|
||||||
n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]).to(device)
|
n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]).to(device)
|
||||||
n01 = dot(tile_grads([0, -1],[1, None]), [0, -1]).to(device)
|
n01 = dot(tile_grads([0, -1], [1, None]), [0, -1]).to(device)
|
||||||
n11 = dot(tile_grads([1, None], [1, None]), [-1,-1]).to(device)
|
n11 = dot(tile_grads([1, None], [1, None]), [-1, -1]).to(device)
|
||||||
t = fade(grid[:shape[0], :shape[1]])
|
t = fade(grid[: shape[0], : shape[1]])
|
||||||
noise = math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]).to(device)
|
noise = math.sqrt(2) * torch.lerp(
|
||||||
|
torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]
|
||||||
|
).to(device)
|
||||||
return noise.to(dtype=torch_dtype(device))
|
return noise.to(dtype=torch_dtype(device))
|
||||||
|
|
||||||
|
|
||||||
def ask_user(question: str, answers: list):
|
def ask_user(question: str, answers: list):
|
||||||
from itertools import chain, repeat
|
from itertools import chain, repeat
|
||||||
user_prompt = f'\n>> {question} {answers}: '
|
|
||||||
invalid_answer_msg = 'Invalid answer. Please try again.'
|
user_prompt = f"\n>> {question} {answers}: "
|
||||||
pose_question = chain([user_prompt], repeat('\n'.join([invalid_answer_msg, user_prompt])))
|
invalid_answer_msg = "Invalid answer. Please try again."
|
||||||
|
pose_question = chain(
|
||||||
|
[user_prompt], repeat("\n".join([invalid_answer_msg, user_prompt]))
|
||||||
|
)
|
||||||
user_answers = map(input, pose_question)
|
user_answers = map(input, pose_question)
|
||||||
valid_response = next(filter(answers.__contains__, user_answers))
|
valid_response = next(filter(answers.__contains__, user_answers))
|
||||||
return valid_response
|
return valid_response
|
||||||
|
|
||||||
|
|
||||||
def debug_image(debug_image, debug_text, debug_show=True, debug_result=False, debug_status=False ):
|
def debug_image(
|
||||||
|
debug_image, debug_text, debug_show=True, debug_result=False, debug_status=False
|
||||||
|
):
|
||||||
if not debug_status:
|
if not debug_status:
|
||||||
return
|
return
|
||||||
|
|
||||||
image_copy = debug_image.copy().convert("RGBA")
|
image_copy = debug_image.copy().convert("RGBA")
|
||||||
ImageDraw.Draw(image_copy).text(
|
ImageDraw.Draw(image_copy).text((5, 5), debug_text, (255, 0, 0))
|
||||||
(5, 5),
|
|
||||||
debug_text,
|
|
||||||
(255, 0, 0)
|
|
||||||
)
|
|
||||||
|
|
||||||
if debug_show:
|
if debug_show:
|
||||||
image_copy.show()
|
image_copy.show()
|
||||||
@@ -266,31 +295,84 @@ def debug_image(debug_image, debug_text, debug_show=True, debug_result=False, de
     if debug_result:
         return image_copy
 
-#-------------------------------------
-class ProgressBar():
-    def __init__(self,model_name='file'):
-        self.pbar = None
-        self.name = model_name
-
-    def __call__(self, block_num, block_size, total_size):
-        if not self.pbar:
-            self.pbar=tqdm(desc=self.name,
-                           initial=0,
-                           unit='iB',
-                           unit_scale=True,
-                           unit_divisor=1000,
-                           total=total_size)
-        self.pbar.update(block_size)
-
-def download_with_progress_bar(url:str, dest:Path)->bool:
+# -------------------------------------
+def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path:
+    '''
+    Download a model file.
+    :param url: https, http or ftp URL
+    :param dest: A Path object. If path exists and is a directory, then we try to derive the filename
+                 from the URL's Content-Disposition header and copy the URL contents into
+                 dest/filename
+    :param access_token: Access token to access this resource
+    '''
+    resp = requests.get(url, stream=True)
+    total = int(resp.headers.get("content-length", 0))
+
+    if dest.is_dir():
+        try:
+            file_name = re.search('filename="(.+)"', resp.headers.get("Content-Disposition")).group(1)
+        except:
+            file_name = os.path.basename(url)
+        dest = dest / file_name
+    else:
+        dest.parent.mkdir(parents=True, exist_ok=True)
+
+    print(f'DEBUG: after many manipulations, dest={dest}')
+
+    header = {"Authorization": f"Bearer {access_token}"} if access_token else {}
+    open_mode = "wb"
+    exist_size = 0
+
+    if dest.exists():
+        exist_size = dest.stat().st_size
+        header["Range"] = f"bytes={exist_size}-"
+        open_mode = "ab"
+
+    if (
+        resp.status_code == 416
+    ): # "range not satisfiable", which means nothing to return
+        print(f"* {dest}: complete file found. Skipping.")
+        return dest
+    elif resp.status_code != 200:
+        print(f"** An error occurred during downloading {dest}: {resp.reason}")
+    elif exist_size > 0:
+        print(f"* {dest}: partial file found. Resuming...")
+    else:
+        print(f"* {dest}: Downloading...")
+
     try:
-        if not dest.exists():
-            dest.parent.mkdir(parents=True, exist_ok=True)
-            request.urlretrieve(url,dest,ProgressBar(dest.stem))
-            return True
-        else:
-            return True
-    except OSError:
-        print(traceback.format_exc())
-        return False
+        if total < 2000:
+            print(f"*** ERROR DOWNLOADING {url}: {resp.text}")
+            return None
+
+        with open(dest, open_mode) as file, tqdm(
+            desc=str(dest),
+            initial=exist_size,
+            total=total + exist_size,
+            unit="iB",
+            unit_scale=True,
+            unit_divisor=1000,
+        ) as bar:
+            for data in resp.iter_content(chunk_size=1024):
+                size = file.write(data)
+                bar.update(size)
+    except Exception as e:
+        print(f"An error occurred while downloading {dest}: {str(e)}")
+        return None
+
+    return dest
+
+
+def url_attachment_name(url: str) -> dict:
+    try:
+        resp = requests.get(url, stream=True)
+        match = re.search('filename="(.+)"', resp.headers.get("Content-Disposition"))
+        return match.group(1)
+    except:
+        return None
+
+
+def download_with_progress_bar(url: str, dest: Path) -> bool:
+    result = download_with_resume(url, dest, access_token=None)
+    return result is not None
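Roughly how the two new ldm/util.py helpers are meant to be called after this change (the URL below is a placeholder, not a value from the commit; the destination directory mirrors the default used by ModelManager._resolve_path):

from pathlib import Path

from ldm.util import download_with_resume, url_attachment_name

url = "https://example.com/some-model.ckpt"  # placeholder URL

# Filename the server advertises via Content-Disposition (None if absent).
print(url_attachment_name(url))

# Download into a directory; the file name is derived from the response headers,
# and an existing partial file makes the helper attempt to resume the transfer.
dest = download_with_resume(url, Path("models/ldm/stable-diffusion-v1"))
print(f"downloaded to {dest}")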