InvokeAI/ldm/util.py

215 lines
5.8 KiB
Python
Raw Normal View History

2021-12-21 02:23:41 +00:00
import importlib
import torch
import numpy as np
2022-08-10 14:30:49 +00:00
from collections import abc
from einops import rearrange
from functools import partial
import multiprocessing as mp
from threading import Thread
from queue import Queue
2021-12-21 02:23:41 +00:00
from inspect import isfunction
from PIL import Image, ImageDraw, ImageFont
2022-08-23 22:26:28 +00:00
2021-12-21 02:23:41 +00:00
def log_txt_as_img(wh, xc, size=10):
    """Render every caption in ``xc`` onto its own white RGB canvas.

    Args:
        wh: ``(width, height)`` of each canvas, passed straight to ``Image.new``.
        xc: sequence of caption strings; one canvas is produced per caption.
        size: accepted for API compatibility but NOT used here -- the font
            always comes from ``ImageFont.load_default()``.

    Returns:
        ``torch.tensor(np.stack(...))`` of the rendered canvases. Each canvas is
        transposed to channel-first and mapped from 0..255 to [-1, 1] via
        ``/ 127.5 - 1.0``.
    """
    rendered = []
    for caption in xc:
        canvas = Image.new('RGB', wh, color='white')
        painter = ImageDraw.Draw(canvas)
        font = ImageFont.load_default()
        # Characters per line scale with canvas width (40 chars at 256 px).
        chars_per_line = int(40 * (wh[0] / 256))
        wrapped = '\n'.join(
            caption[pos : pos + chars_per_line]
            for pos in range(0, len(caption), chars_per_line)
        )
        try:
            painter.text((0, 0), wrapped, fill='black', font=font)
        except UnicodeEncodeError:
            # Best effort: an unencodable caption leaves a blank canvas.
            print('Cant encode string for logging. Skipping.')
        channels_first = np.array(canvas).transpose(2, 0, 1)
        rendered.append(channels_first / 127.5 - 1.0)
    stacked = np.stack(rendered)
    return torch.tensor(stacked)
def ismap(x):
    """Return True for a 4-D ``torch.Tensor`` whose dim 1 is larger than 3.

    Anything that is not a ``torch.Tensor`` yields False.
    """
    return isinstance(x, torch.Tensor) and x.ndim == 4 and x.shape[1] > 3
def isimage(x):
    """Return True for a 4-D ``torch.Tensor`` whose dim 1 has size 1 or 3.

    Presumably used to spot (N, C, H, W) grayscale/RGB batches -- the layout
    itself is assumed, not verified. Non-tensors yield False.
    """
    if isinstance(x, torch.Tensor):
        return x.ndim == 4 and x.shape[1] in (1, 3)
    return False
def exists(x):
    """Return True when ``x`` is any value other than ``None`` (falsy values count)."""
    if x is None:
        return False
    return True
def default(val, d):
    """Return ``val`` if it is not ``None``, otherwise the fallback ``d``.

    When falling back, ``d`` is called with no arguments if
    ``inspect.isfunction(d)`` is true, so plain functions and lambdas act as
    lazily evaluated defaults. Any other object is returned as-is.
    """
    if exists(val):
        return val
    if isfunction(d):
        return d()
    return d
def mean_flat(tensor):
    """Average ``tensor`` over every dimension except dim 0 (the batch dim).

    Adapted from
    https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
    """
    reduce_dims = list(range(1, tensor.ndim))
    return tensor.mean(dim=reduce_dims)
def count_params(model, verbose=False):
    """Return the total element count over ``model.parameters()``.

    With ``verbose`` set, the count is also printed in millions, prefixed by
    the model's class name.
    """
    n_params = sum(param.numel() for param in model.parameters())
    if not verbose:
        return n_params
    print(
        f'{model.__class__.__name__} has {n_params * 1.e-6:.2f} M params.'
    )
    return n_params
2022-08-23 22:26:28 +00:00
def instantiate_from_config(config, **kwargs):
    """Build the object described by ``config``.

    ``config['target']`` is a dotted import path resolved by
    ``get_obj_from_str``. The result is called with ``config['params']``
    (default: empty) plus any extra ``kwargs``.

    The sentinel strings ``'__is_first_stage__'`` and
    ``'__is_unconditional__'`` return ``None``.

    Raises:
        KeyError: ``config`` has no ``'target'`` and is not a sentinel.
    """
    if 'target' in config:
        target_cls = get_obj_from_str(config['target'])
        params = config.get('params', dict())
        return target_cls(**params, **kwargs)
    if config in ('__is_first_stage__', '__is_unconditional__'):
        return None
    raise KeyError('Expected key `target` to instantiate.')
2021-12-21 02:23:41 +00:00
def get_obj_from_str(string, reload=False):
    """Import and return the attribute named by a dotted path.

    ``'pkg.mod.Name'`` imports ``pkg.mod`` and returns its ``Name``
    attribute. When ``reload`` is true, the module is re-executed with
    ``importlib.reload`` before the lookup.

    Raises:
        ValueError: ``string`` contains no ``'.'``.
    """
    module_path, attr_name = string.rsplit('.', 1)
    if reload:
        importlib.reload(importlib.import_module(module_path))
    module = importlib.import_module(module_path, package=None)
    return getattr(module, attr_name)
def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
    """Worker body: process one chunk and report back through ``Q``.

    Puts ``[idx, func(data)]`` on ``Q`` and then the marker string ``'Done'``.
    With ``idx_to_fn`` set, ``func`` is called as ``func(data, worker_id=idx)``.
    If ``func`` raises, neither item is posted.
    """
    extra_kwargs = {'worker_id': idx} if idx_to_fn else {}
    outcome = func(data, **extra_kwargs)
    Q.put([idx, outcome])
    Q.put('Done')
2022-08-10 14:30:49 +00:00
def _normalize_prefetch_data(data, target_data_type):
    """Validate ``data`` and convert it to the container used for splitting.

    Returns ``np.asarray(data)`` when ``target_data_type == 'ndarray'`` and
    ``list(data)`` otherwise. For a ``dict`` only the values are kept.

    Raises:
        ValueError: ``data`` is an ndarray but ``target_data_type == 'list'``.
        TypeError: ``data`` is not iterable.
    """
    if isinstance(data, np.ndarray) and target_data_type == 'list':
        raise ValueError('list expected but function got ndarray.')
    if not isinstance(data, abc.Iterable):
        raise TypeError(
            f'The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}.'
        )
    if isinstance(data, dict):
        print(
            'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
        )
        data = list(data.values())
    if target_data_type == 'ndarray':
        return np.asarray(data)
    return list(data)


def _split_prefetch_data(data, n_proc, target_data_type):
    """Cut ``data`` into contiguous chunks (one per worker), preserving order."""
    if target_data_type == 'ndarray':
        # np.array_split always returns exactly n_proc (possibly empty) pieces.
        return np.array_split(data, n_proc)
    # ceil(len / n_proc) items per chunk, matching the previous slicing. This
    # can yield FEWER than n_proc chunks (e.g. 9 items / 4 workers -> 3 chunks)
    # and none at all for empty input; max(1, ...) keeps the range step valid.
    step = max(1, (len(data) + n_proc - 1) // n_proc)
    return [data[i : i + step] for i in range(0, len(data), step)]


def _merge_prefetch_results(gather_res, target_data_type):
    """Combine per-worker results (ordered by chunk index) into one output."""
    if target_data_type == 'ndarray':
        if not isinstance(gather_res[0], np.ndarray):
            return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
        return np.concatenate(gather_res, axis=0)
    if target_data_type == 'list':
        out = []
        for r in gather_res:
            out.extend(r)
        return out
    return gather_res


def parallel_data_prefetch(
    func: callable,
    data,
    n_proc,
    target_data_type='ndarray',
    cpu_intensive=True,
    use_worker_id=False,
):
    """Run ``func`` over chunks of ``data`` in parallel workers and merge the results.

    Args:
        func: Called once per chunk (via ``_do_parallel_data_prefetch``) as
            ``func(chunk)``, or as ``func(chunk, worker_id=i)`` when
            ``use_worker_id`` is true.
        data: An ``np.ndarray`` or any iterable. For a ``dict`` only the
            values are used, and a warning is printed.
        n_proc: Upper bound on the number of workers. ``'ndarray'`` data is
            split into exactly ``n_proc`` pieces. Other data is sliced into
            chunks of ``ceil(len(data) / n_proc)`` items, which can mean fewer
            workers (zero for empty input).
        target_data_type: ``'ndarray'`` concatenates the per-worker results
            along axis 0. ``'list'`` concatenates them with ``list.extend``.
            Any other value returns the list of per-worker results in chunk
            order.
        cpu_intensive: ``True`` runs ``multiprocessing.Process`` workers;
            ``False`` runs ``threading.Thread`` workers.
        use_worker_id: Forward each chunk's index to ``func`` as ``worker_id``.

    Raises:
        ValueError: ``data`` is an ndarray but ``target_data_type == 'list'``,
            or ``n_proc < 1``.
        TypeError: ``data`` is not iterable.
    """
    import time

    data = _normalize_prefetch_data(data, target_data_type)
    if n_proc < 1:
        raise ValueError(f'n_proc must be at least 1, got {n_proc}.')

    if cpu_intensive:
        result_queue = mp.Queue(1000)
        worker_cls = mp.Process
    else:
        result_queue = Queue(1000)
        worker_cls = Thread

    chunks = _split_prefetch_data(data, n_proc, target_data_type)
    # One worker per chunk actually produced, so list data shorter than (or
    # not evenly divisible by) n_proc no longer indexes past the chunk list.
    workers = [
        worker_cls(
            target=_do_parallel_data_prefetch,
            args=[func, result_queue, part, i, use_worker_id],
        )
        for i, part in enumerate(chunks)
    ]
    n_workers = len(workers)

    print('Start prefetching...')
    start = time.time()
    gather_res = [[] for _ in range(n_workers)]
    started = []
    try:
        for w in workers:
            w.start()
            started.append(w)
        # NOTE: a worker posts [idx, result] and then 'Done' only if `func`
        # returns. If `func` raises inside a worker, 'Done' never arrives and
        # this loop blocks.
        finished = 0
        while finished < n_workers:
            res = result_queue.get()
            if res == 'Done':
                finished += 1
            else:
                gather_res[res[0]] = res[1]
    except Exception as e:
        print('Exception: ', e)
        for w in started:
            # Threads cannot be killed; only multiprocessing workers have terminate().
            if hasattr(w, 'terminate'):
                w.terminate()
        raise
    finally:
        # Join only workers that were actually started; joining an unstarted
        # worker raises and would mask the original error.
        for w in started:
            w.join()
    print(f'Prefetching complete. [{time.time() - start} sec.]')

    return _merge_prefetch_results(gather_res, target_data_type)