Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

chore: remove unused files
parent 26f7adeaa3, commit 9cb04f6f80
@@ -1,6 +0,0 @@
from ldm.modules.image_degradation.bsrgan import (  # noqa: F401
    degradation_bsrgan_variant as degradation_fn_bsr,
)
from ldm.modules.image_degradation.bsrgan_light import (  # noqa: F401
    degradation_bsrgan_variant as degradation_fn_bsr_light,
)
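
For context, these two re-exports were the package's public degradation entry points. A hypothetical caller (not shown in this commit) would have consumed them roughly like this, assuming a uint8 HxWxC array named hq_uint8:

    # illustrative sketch, not part of the removed files
    from ldm.modules.image_degradation import degradation_fn_bsr_light

    lq = degradation_fn_bsr_light(image=hq_uint8, sf=4)["image"]  # the variant returns {"image": uint8 array}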
@@ -1,794 +0,0 @@
# -*- coding: utf-8 -*-
"""
# --------------------------------------------
# Super-Resolution
# --------------------------------------------
#
# Kai Zhang (cskaizhang@gmail.com)
# https://github.com/cszn
# From 2019/03--2021/08
# --------------------------------------------
"""

import random
from functools import partial

import albumentations
import cv2
import ldm.modules.image_degradation.utils_image as util
import numpy as np
import scipy
import scipy.stats as ss
import torch
from scipy import ndimage
from scipy.interpolate import interp2d
from scipy.linalg import orth


def modcrop_np(img, sf):
    """
    Args:
        img: numpy image, WxH or WxHxC
        sf: scale factor
    Return:
        cropped image
    """
    w, h = img.shape[:2]
    im = np.copy(img)
    return im[: w - w % sf, : h - h % sf, ...]


"""
# --------------------------------------------
# anisotropic Gaussian kernels
# --------------------------------------------
"""


def analytic_kernel(k):
    """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
    k_size = k.shape[0]
    # Calculate the big kernel's size
    big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
    # Loop over the small kernel to fill the big one
    for r in range(k_size):
        for c in range(k_size):
            big_k[2 * r : 2 * r + k_size, 2 * c : 2 * c + k_size] += k[r, c] * k
    # Crop the edges of the big kernel to ignore very small values and reduce the SR run time
    crop = k_size // 2
    cropped_big_k = big_k[crop:-crop, crop:-crop]
    # Normalize to 1
    return cropped_big_k / cropped_big_k.sum()


def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
    """generate an anisotropic Gaussian kernel
    Args:
        ksize : e.g., 15, kernel size
        theta : [0, pi], rotation angle range
        l1    : [0.1, 50], scaling of eigenvalues
        l2    : [0.1, l1], scaling of eigenvalues
        If l1 = l2, the kernel is isotropic.
    Returns:
        k     : kernel
    """

    v = np.dot(
        np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]),
        np.array([1.0, 0.0]),
    )
    V = np.array([[v[0], v[1]], [v[1], -v[0]]])
    D = np.array([[l1, 0], [0, l2]])
    Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
    k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)

    return k


def gm_blur_kernel(mean, cov, size=15):
    center = size / 2.0 + 0.5
    k = np.zeros([size, size])
    for y in range(size):
        for x in range(size):
            cy = y - center + 1
            cx = x - center + 1
            k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)

    k = k / np.sum(k)
    return k
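
As a quick sanity check of the two kernel helpers above, a minimal sketch (illustrative values, assuming this module's namespace):

    # illustrative only, not part of the removed file
    k = anisotropic_Gaussian(ksize=15, theta=np.pi / 4, l1=6.0, l2=1.0)
    assert k.shape == (15, 15)
    assert abs(k.sum() - 1.0) < 1e-8  # gm_blur_kernel normalizes the kernel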


def shift_pixel(x, sf, upper_left=True):
    """shift pixel for super-resolution with different scale factors
    Args:
        x: WxHxC or WxH
        sf: scale factor
        upper_left: shift direction
    """
    h, w = x.shape[:2]
    shift = (sf - 1) * 0.5
    xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
    if upper_left:
        x1 = xv + shift
        y1 = yv + shift
    else:
        x1 = xv - shift
        y1 = yv - shift

    x1 = np.clip(x1, 0, w - 1)
    y1 = np.clip(y1, 0, h - 1)

    # NOTE: scipy.interpolate.interp2d is deprecated and removed in SciPy >= 1.14;
    # RegularGridInterpolator is the suggested replacement.
    if x.ndim == 2:
        x = interp2d(xv, yv, x)(x1, y1)
    if x.ndim == 3:
        for i in range(x.shape[-1]):
            x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)

    return x


def blur(x, k):
    """
    x: image, NxcxHxW
    k: kernel, Nx1xhxw
    """
    n, c = x.shape[:2]
    p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
    x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode="replicate")
    k = k.repeat(1, c, 1, 1)
    k = k.view(-1, 1, k.shape[2], k.shape[3])
    x = x.view(1, -1, x.shape[2], x.shape[3])
    x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
    x = x.view(n, c, x.shape[2], x.shape[3])

    return x


def gen_kernel(
    k_size=np.array([15, 15]),
    scale_factor=np.array([4, 4]),
    min_var=0.6,
    max_var=10.0,
    noise_level=0,
):
    """
    # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
    # Kai Zhang
    # min_var = 0.175 * sf  # variance of the gaussian kernel will be sampled between min_var and max_var
    # max_var = 2.5 * sf
    """
    # Set random eigenvalues (lambdas) and angle (theta) for the covariance matrix
    lambda_1 = min_var + np.random.rand() * (max_var - min_var)
    lambda_2 = min_var + np.random.rand() * (max_var - min_var)
    theta = np.random.rand() * np.pi  # random theta
    noise = -noise_level + np.random.rand(*k_size) * noise_level * 2

    # Set the covariance matrix using lambdas and theta
    LAMBDA = np.diag([lambda_1, lambda_2])
    Q = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
    SIGMA = Q @ LAMBDA @ Q.T
    INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]

    # Set expectation position (shifting kernel for aligned image)
    MU = k_size // 2 - 0.5 * (scale_factor - 1)  # - 0.5 * (scale_factor - k_size % 2)
    MU = MU[None, None, :, None]

    # Create meshgrid for Gaussian
    [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
    Z = np.stack([X, Y], 2)[:, :, :, None]

    # Calculate the Gaussian for every pixel of the kernel
    ZZ = Z - MU
    ZZ_t = ZZ.transpose(0, 1, 3, 2)
    raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)

    # shift the kernel so it will be centered
    # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)

    # Normalize the kernel and return
    # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
    kernel = raw_kernel / np.sum(raw_kernel)
    return kernel
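
gen_kernel draws a random anisotropic Gaussian; with noise_level=0 the result is non-negative and normalized, which a short check confirms (illustrative sketch):

    # illustrative only, not part of the removed file
    k = gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), noise_level=0)
    assert k.shape == (15, 15) and k.min() >= 0
    assert abs(k.sum() - 1.0) < 1e-8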


def fspecial_gaussian(hsize, sigma):
    hsize = [hsize, hsize]
    siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
    std = sigma
    [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
    arg = -(x * x + y * y) / (2 * std * std)
    h = np.exp(arg)
    # NOTE: scipy.finfo was a re-export of np.finfo and is gone from modern SciPy;
    # np.finfo(float).eps is the equivalent.
    h[h < scipy.finfo(float).eps * h.max()] = 0
    sumh = h.sum()
    if sumh != 0:
        h = h / sumh
    return h


def fspecial_laplacian(alpha):
    alpha = max([0, min([alpha, 1])])
    h1 = alpha / (alpha + 1)
    h2 = (1 - alpha) / (alpha + 1)
    h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
    h = np.array(h)
    return h


def fspecial(filter_type, *args, **kwargs):
    """
    python code from:
    https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
    """
    if filter_type == "gaussian":
        return fspecial_gaussian(*args, **kwargs)
    if filter_type == "laplacian":
        return fspecial_laplacian(*args, **kwargs)


"""
# --------------------------------------------
# degradation models
# --------------------------------------------
"""
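
fspecial dispatches on the filter name, mirroring MATLAB's fspecial; anything other than "gaussian" or "laplacian" silently returns None. A hedged usage sketch:

    # illustrative only, not part of the removed file
    g = fspecial("gaussian", 7, 1.5)    # 7x7 Gaussian kernel, sigma 1.5
    lap = fspecial("laplacian", 0.2)    # 3x3 Laplacian kernel, alpha 0.2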


def bicubic_degradation(x, sf=3):
    """
    Args:
        x: HxWxC image, [0, 1]
        sf: down-scale factor
    Return:
        bicubically downsampled LR image
    """
    x = util.imresize_np(x, scale=1 / sf)
    return x


def srmd_degradation(x, k, sf=3):
    """blur + bicubic downsampling
    Args:
        x: HxWxC image, [0, 1]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    Reference:
        @inproceedings{zhang2018learning,
          title={Learning a single convolutional super-resolution network for multiple degradations},
          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
          pages={3262--3271},
          year={2018}
        }
    """
    # NOTE: ndimage.filters is a deprecated alias; modern SciPy exposes ndimage.convolve directly.
    x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode="wrap")  # 'nearest' | 'mirror'
    x = bicubic_degradation(x, sf=sf)
    return x


def dpsr_degradation(x, k, sf=3):
    """bicubic downsampling + blur
    Args:
        x: HxWxC image, [0, 1]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    Reference:
        @inproceedings{zhang2019deep,
          title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
          pages={1671--1681},
          year={2019}
        }
    """
    x = bicubic_degradation(x, sf=sf)
    x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode="wrap")
    return x


def classical_degradation(x, k, sf=3):
    """blur + downsampling
    Args:
        x: HxWxC image, [0, 1]/[0, 255]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    """
    x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode="wrap")
    # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
    st = 0
    return x[st::sf, st::sf, ...]
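
The three classical pipelines above differ only in operation order; a side-by-side sketch (hq is a hypothetical float image in [0, 1]):

    # illustrative only, not part of the removed file
    k = fspecial("gaussian", 7, 1.6)
    lr_srmd = srmd_degradation(hq, k, sf=3)             # blur, then bicubic downsample
    lr_dpsr = dpsr_degradation(hq, k, sf=3)             # bicubic downsample, then blur
    lr_classical = classical_degradation(hq, k, sf=3)   # blur, then s-fold direct subsampling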


def add_sharpening(img, weight=0.5, radius=50, threshold=10):
    """USM sharpening, borrowed from Real-ESRGAN.
    Input image: I; Blurry image: B.
    1. K = I + weight * (I - B)
    2. Mask = 1 if abs(I - B) > threshold, else: 0
    3. Blur the mask.
    4. Out = Mask * K + (1 - Mask) * I
    Args:
        img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
        weight (float): Sharp weight. Default: 0.5.
        radius (float): Kernel size of the Gaussian blur. Default: 50.
        threshold (int): Mask threshold, compared against the residual scaled to [0, 255].
    """
    if radius % 2 == 0:
        radius += 1
    blur = cv2.GaussianBlur(img, (radius, radius), 0)
    residual = img - blur
    mask = np.abs(residual) * 255 > threshold
    mask = mask.astype("float32")
    soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)

    K = img + weight * residual
    K = np.clip(K, 0, 1)
    return soft_mask * K + (1 - soft_mask) * img


def add_blur(img, sf=4):
    wd2 = 4.0 + sf
    wd = 2.0 + 0.2 * sf
    if random.random() < 0.5:
        l1 = wd2 * random.random()
        l2 = wd2 * random.random()
        k = anisotropic_Gaussian(
            ksize=2 * random.randint(2, 11) + 3,
            theta=random.random() * np.pi,
            l1=l1,
            l2=l2,
        )
    else:
        k = fspecial("gaussian", 2 * random.randint(2, 11) + 3, wd * random.random())
    img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode="mirror")

    return img


def add_resize(img, sf=4):
    rnum = np.random.rand()
    if rnum > 0.8:  # up
        sf1 = random.uniform(1, 2)
    elif rnum < 0.7:  # down
        sf1 = random.uniform(0.5 / sf, 1)
    else:
        sf1 = 1.0
    img = cv2.resize(
        img,
        (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])),
        interpolation=random.choice([1, 2, 3]),
    )
    img = np.clip(img, 0.0, 1.0)

    return img


# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
#     noise_level = random.randint(noise_level1, noise_level2)
#     rnum = np.random.rand()
#     if rnum > 0.6:  # add color Gaussian noise
#         img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
#     elif rnum < 0.4:  # add grayscale Gaussian noise
#         img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
#     else:  # add noise
#         L = noise_level2 / 255.
#         D = np.diag(np.random.rand(3))
#         U = orth(np.random.rand(3, 3))
#         conv = np.dot(np.dot(np.transpose(U), D), U)
#         img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
#     img = np.clip(img, 0.0, 1.0)
#     return img


def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
    noise_level = random.randint(noise_level1, noise_level2)
    rnum = np.random.rand()
    if rnum > 0.6:  # add color Gaussian noise
        img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
    elif rnum < 0.4:  # add grayscale Gaussian noise
        img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
    else:  # add channel-correlated color noise
        L = noise_level2 / 255.0
        D = np.diag(np.random.rand(3))
        U = orth(np.random.rand(3, 3))
        conv = np.dot(np.dot(np.transpose(U), D), U)
        img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32)
    img = np.clip(img, 0.0, 1.0)
    return img
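
A usage sketch for the noise injector above (img01 is a hypothetical float32 HxWx3 image in [0, 1]; the correlated-noise branch requires three channels):

    # illustrative only, not part of the removed file
    noisy = add_Gaussian_noise(img01, noise_level1=2, noise_level2=25)
    # output is clipped back to [0, 1]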


def add_speckle_noise(img, noise_level1=2, noise_level2=25):
    noise_level = random.randint(noise_level1, noise_level2)
    img = np.clip(img, 0.0, 1.0)
    rnum = random.random()
    if rnum > 0.6:
        img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
    elif rnum < 0.4:
        img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
    else:
        L = noise_level2 / 255.0
        D = np.diag(np.random.rand(3))
        U = orth(np.random.rand(3, 3))
        conv = np.dot(np.dot(np.transpose(U), D), U)
        img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32)
    img = np.clip(img, 0.0, 1.0)
    return img


def add_Poisson_noise(img):
    img = np.clip((img * 255.0).round(), 0, 255) / 255.0
    vals = 10 ** (2 * random.random() + 2.0)  # exponent uniform in [2, 4]
    if random.random() < 0.5:
        img = np.random.poisson(img * vals).astype(np.float32) / vals
    else:
        img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
        img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.0
        noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
        img += noise_gray[:, :, np.newaxis]
    img = np.clip(img, 0.0, 1.0)
    return img


def add_JPEG_noise(img):
    quality_factor = random.randint(30, 95)
    img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
    result, encimg = cv2.imencode(".jpg", img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
    img = cv2.imdecode(encimg, 1)
    img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
    return img
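
add_JPEG_noise round-trips the image through an in-memory JPEG encode/decode at a random quality in [30, 95]; a sketch (img01 as above):

    # illustrative only, not part of the removed file
    jpeg_img = add_JPEG_noise(img01)  # RGB float in, RGB float out, with JPEG artifacts baked in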


def random_crop(lq, hq, sf=4, lq_patchsize=64):
    h, w = lq.shape[:2]
    rnd_h = random.randint(0, h - lq_patchsize)
    rnd_w = random.randint(0, w - lq_patchsize)
    lq = lq[rnd_h : rnd_h + lq_patchsize, rnd_w : rnd_w + lq_patchsize, :]

    rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
    hq = hq[
        rnd_h_H : rnd_h_H + lq_patchsize * sf,
        rnd_w_H : rnd_w_H + lq_patchsize * sf,
        :,
    ]
    return lq, hq


def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
    """
    This is the degradation model of BSRGAN from the paper
    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
    ----------
    img: HxWxC, [0, 1]; its size should be larger than (lq_patchsize x sf) x (lq_patchsize x sf)
    sf: scale factor
    isp_model: camera ISP model
    Returns
    -------
    img: low-quality patch, size: lq_patchsize x lq_patchsize x C, range: [0, 1]
    hq: corresponding high-quality patch, size: (lq_patchsize x sf) x (lq_patchsize x sf) x C, range: [0, 1]
    """
    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
    sf_ori = sf

    h1, w1 = img.shape[:2]
    img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...]  # mod crop (note: h/w indices are swapped here, as in the upstream source)
    h, w = img.shape[:2]

    if h < lq_patchsize * sf or w < lq_patchsize * sf:
        raise ValueError(f"img size ({h1}X{w1}) is too small!")

    hq = img.copy()

    if sf == 4 and random.random() < scale2_prob:  # downsample1
        if np.random.rand() < 0.5:
            img = cv2.resize(
                img,
                (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
                interpolation=random.choice([1, 2, 3]),
            )
        else:
            img = util.imresize_np(img, 1 / 2, True)
        img = np.clip(img, 0.0, 1.0)
        sf = 2

    shuffle_order = random.sample(range(7), 7)
    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
    if idx1 > idx2:  # keep downsample3 last
        shuffle_order[idx1], shuffle_order[idx2] = (
            shuffle_order[idx2],
            shuffle_order[idx1],
        )

    for i in shuffle_order:
        if i == 0:
            img = add_blur(img, sf=sf)

        elif i == 1:
            img = add_blur(img, sf=sf)

        elif i == 2:
            a, b = img.shape[1], img.shape[0]
            # downsample2
            if random.random() < 0.75:
                sf1 = random.uniform(1, 2 * sf)
                img = cv2.resize(
                    img,
                    (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
                    interpolation=random.choice([1, 2, 3]),
                )
            else:
                k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf))
                k_shifted = shift_pixel(k, sf)
                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
                img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode="mirror")
                img = img[0::sf, 0::sf, ...]  # nearest downsampling
            img = np.clip(img, 0.0, 1.0)

        elif i == 3:
            # downsample3
            img = cv2.resize(
                img,
                (int(1 / sf * a), int(1 / sf * b)),
                interpolation=random.choice([1, 2, 3]),
            )
            img = np.clip(img, 0.0, 1.0)

        elif i == 4:
            # add Gaussian noise
            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)

        elif i == 5:
            # add JPEG noise
            if random.random() < jpeg_prob:
                img = add_JPEG_noise(img)

        elif i == 6:
            # add processed camera sensor noise
            if random.random() < isp_prob and isp_model is not None:
                with torch.no_grad():
                    img, hq = isp_model.forward(img.copy(), hq)

    # add final JPEG compression noise
    img = add_JPEG_noise(img)

    # random crop
    img, hq = random_crop(img, hq, sf_ori, lq_patchsize)

    return img, hq
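
A usage sketch for the full pipeline above (img01 is a hypothetical float HxWxC image in [0, 1], at least 288x288 for the defaults sf=4 and lq_patchsize=72):

    # illustrative only, not part of the removed file
    lq, hq = degradation_bsrgan(img01, sf=4, lq_patchsize=72)
    # lq: 72x72xC degraded patch; hq: matching 288x288xC clean patch, both in [0, 1]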


# TODO: no isp_model?
def degradation_bsrgan_variant(image, sf=4, isp_model=None):
    """
    This is the degradation model of BSRGAN from the paper
    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
    ----------
    sf: scale factor
    isp_model: camera ISP model
    Returns
    -------
    example: dict with key "image" holding the degraded low-quality image (uint8)
    """
    image = util.uint2single(image)
    jpeg_prob, scale2_prob = 0.9, 0.25
    # isp_prob = 0.25  # uncomment with the `if i == 6` block below
    # sf_ori = sf  # uncomment with the `if i == 6` block below

    h1, w1 = image.shape[:2]
    image = image.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...]  # mod crop
    h, w = image.shape[:2]

    # hq = image.copy()  # uncomment with the `if i == 6` block below

    if sf == 4 and random.random() < scale2_prob:  # downsample1
        if np.random.rand() < 0.5:
            image = cv2.resize(
                image,
                (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
                interpolation=random.choice([1, 2, 3]),
            )
        else:
            image = util.imresize_np(image, 1 / 2, True)
        image = np.clip(image, 0.0, 1.0)
        sf = 2

    shuffle_order = random.sample(range(7), 7)
    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
    if idx1 > idx2:  # keep downsample3 last
        shuffle_order[idx1], shuffle_order[idx2] = (
            shuffle_order[idx2],
            shuffle_order[idx1],
        )

    for i in shuffle_order:
        if i == 0:
            image = add_blur(image, sf=sf)

        elif i == 1:
            image = add_blur(image, sf=sf)

        elif i == 2:
            a, b = image.shape[1], image.shape[0]
            # downsample2
            if random.random() < 0.75:
                sf1 = random.uniform(1, 2 * sf)
                image = cv2.resize(
                    image,
                    (
                        int(1 / sf1 * image.shape[1]),
                        int(1 / sf1 * image.shape[0]),
                    ),
                    interpolation=random.choice([1, 2, 3]),
                )
            else:
                k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf))
                k_shifted = shift_pixel(k, sf)
                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
                image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode="mirror")
                image = image[0::sf, 0::sf, ...]  # nearest downsampling
            image = np.clip(image, 0.0, 1.0)

        elif i == 3:
            # downsample3
            image = cv2.resize(
                image,
                (int(1 / sf * a), int(1 / sf * b)),
                interpolation=random.choice([1, 2, 3]),
            )
            image = np.clip(image, 0.0, 1.0)

        elif i == 4:
            # add Gaussian noise
            image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25)

        elif i == 5:
            # add JPEG noise
            if random.random() < jpeg_prob:
                image = add_JPEG_noise(image)

        # elif i == 6:
        #     # add processed camera sensor noise
        #     if random.random() < isp_prob and isp_model is not None:
        #         with torch.no_grad():
        #             img, hq = isp_model.forward(img.copy(), hq)

    # add final JPEG compression noise
    image = add_JPEG_noise(image)
    image = util.single2uint(image)
    example = {"image": image}
    return example


# TODO: in case of a pickle error, replace a += x with a = a + x in add_speckle_noise etc.
def degradation_bsrgan_plus(
    img,
    sf=4,
    shuffle_prob=0.5,
    use_sharp=True,
    lq_patchsize=64,
    isp_model=None,
):
    """
    This is an extended degradation model that combines
    the degradation models of BSRGAN and Real-ESRGAN
    ----------
    img: HxWxC, [0, 1]; its size should be larger than (lq_patchsize x sf) x (lq_patchsize x sf)
    sf: scale factor
    shuffle_prob: probability of shuffling the degradation order
    use_sharp: whether to sharpen the image first
    Returns
    -------
    img: low-quality patch, size: lq_patchsize x lq_patchsize x C, range: [0, 1]
    hq: corresponding high-quality patch, size: (lq_patchsize x sf) x (lq_patchsize x sf) x C, range: [0, 1]
    """

    h1, w1 = img.shape[:2]
    img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...]  # mod crop
    h, w = img.shape[:2]

    if h < lq_patchsize * sf or w < lq_patchsize * sf:
        raise ValueError(f"img size ({h1}X{w1}) is too small!")

    if use_sharp:
        img = add_sharpening(img)
    hq = img.copy()

    if random.random() < shuffle_prob:
        shuffle_order = random.sample(range(13), 13)
    else:
        shuffle_order = list(range(13))
        # local shuffle for noise; JPEG is always the last one
        shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6)))
        shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13)))

    poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1

    for i in shuffle_order:
        if i == 0:
            img = add_blur(img, sf=sf)
        elif i == 1:
            img = add_resize(img, sf=sf)
        elif i == 2:
            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
        elif i == 3:
            if random.random() < poisson_prob:
                img = add_Poisson_noise(img)
        elif i == 4:
            if random.random() < speckle_prob:
                img = add_speckle_noise(img)
        elif i == 5:
            if random.random() < isp_prob and isp_model is not None:
                with torch.no_grad():
                    img, hq = isp_model.forward(img.copy(), hq)
        elif i == 6:
            img = add_JPEG_noise(img)
        elif i == 7:
            img = add_blur(img, sf=sf)
        elif i == 8:
            img = add_resize(img, sf=sf)
        elif i == 9:
            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
        elif i == 10:
            if random.random() < poisson_prob:
                img = add_Poisson_noise(img)
        elif i == 11:
            if random.random() < speckle_prob:
                img = add_speckle_noise(img)
        elif i == 12:
            if random.random() < isp_prob and isp_model is not None:
                with torch.no_grad():
                    img, hq = isp_model.forward(img.copy(), hq)
        else:
            print("check the shuffle!")

    # resize to the desired size
    img = cv2.resize(
        img,
        (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])),
        interpolation=random.choice([1, 2, 3]),
    )

    # add final JPEG compression noise
    img = add_JPEG_noise(img)

    # random crop
    img, hq = random_crop(img, hq, sf, lq_patchsize)

    return img, hq


if __name__ == "__main__":
    print("hey")
    img = util.imread_uint("utils/test.png", 3)
    print(img)
    img = util.uint2single(img)
    print(img)
    img = img[:448, :448]
    h = img.shape[0] // 4
    print("resizing to", h)
    sf = 4
    deg_fn = partial(degradation_bsrgan_variant, sf=sf)
    for i in range(20):
        print(i)
        img_lq = deg_fn(img)["image"]  # degradation_bsrgan_variant returns {"image": uint8 array}
        img_lq = util.uint2single(img_lq)
        print(img_lq)
        img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"]
        print(img_lq.shape)
        print("bicubic", img_lq_bicubic.shape)
        # print(img_hq.shape)
        lq_nearest = cv2.resize(
            util.single2uint(img_lq),
            (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
            interpolation=0,
        )
        lq_bicubic_nearest = cv2.resize(
            util.single2uint(img_lq_bicubic),
            (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
            interpolation=0,
        )
        # img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
        img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest], axis=1)
        util.imsave(img_concat, str(i) + ".png")
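
Note that the __main__ block above only exercises degradation_bsrgan_variant; the extended degradation_bsrgan_plus pipeline would be driven the same way (img01 is a hypothetical float image in [0, 1], at least 256x256 for sf=4 and lq_patchsize=64):

    # illustrative only, not part of the removed file
    lq, hq = degradation_bsrgan_plus(img01, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64)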
@@ -1,704 +0,0 @@
# -*- coding: utf-8 -*-
import random
from functools import partial

import albumentations
import cv2
import ldm.modules.image_degradation.utils_image as util
import numpy as np
import scipy
import scipy.stats as ss
import torch
from scipy import ndimage
from scipy.interpolate import interp2d
from scipy.linalg import orth

"""
# --------------------------------------------
# Super-Resolution
# --------------------------------------------
#
# Kai Zhang (cskaizhang@gmail.com)
# https://github.com/cszn
# From 2019/03--2021/08
# --------------------------------------------
"""


def modcrop_np(img, sf):
    """
    Args:
        img: numpy image, WxH or WxHxC
        sf: scale factor
    Return:
        cropped image
    """
    w, h = img.shape[:2]
    im = np.copy(img)
    return im[: w - w % sf, : h - h % sf, ...]


"""
# --------------------------------------------
# anisotropic Gaussian kernels
# --------------------------------------------
"""


def analytic_kernel(k):
    """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
    k_size = k.shape[0]
    # Calculate the big kernel's size
    big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
    # Loop over the small kernel to fill the big one
    for r in range(k_size):
        for c in range(k_size):
            big_k[2 * r : 2 * r + k_size, 2 * c : 2 * c + k_size] += k[r, c] * k
    # Crop the edges of the big kernel to ignore very small values and reduce the SR run time
    crop = k_size // 2
    cropped_big_k = big_k[crop:-crop, crop:-crop]
    # Normalize to 1
    return cropped_big_k / cropped_big_k.sum()


def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
    """generate an anisotropic Gaussian kernel
    Args:
        ksize : e.g., 15, kernel size
        theta : [0, pi], rotation angle range
        l1    : [0.1, 50], scaling of eigenvalues
        l2    : [0.1, l1], scaling of eigenvalues
        If l1 = l2, the kernel is isotropic.
    Returns:
        k     : kernel
    """

    v = np.dot(
        np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]),
        np.array([1.0, 0.0]),
    )
    V = np.array([[v[0], v[1]], [v[1], -v[0]]])
    D = np.array([[l1, 0], [0, l2]])
    Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
    k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)

    return k


def gm_blur_kernel(mean, cov, size=15):
    center = size / 2.0 + 0.5
    k = np.zeros([size, size])
    for y in range(size):
        for x in range(size):
            cy = y - center + 1
            cx = x - center + 1
            k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)

    k = k / np.sum(k)
    return k


def shift_pixel(x, sf, upper_left=True):
    """shift pixel for super-resolution with different scale factors
    Args:
        x: WxHxC or WxH
        sf: scale factor
        upper_left: shift direction
    """
    h, w = x.shape[:2]
    shift = (sf - 1) * 0.5
    xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
    if upper_left:
        x1 = xv + shift
        y1 = yv + shift
    else:
        x1 = xv - shift
        y1 = yv - shift

    x1 = np.clip(x1, 0, w - 1)
    y1 = np.clip(y1, 0, h - 1)

    # NOTE: scipy.interpolate.interp2d is deprecated and removed in SciPy >= 1.14;
    # RegularGridInterpolator is the suggested replacement.
    if x.ndim == 2:
        x = interp2d(xv, yv, x)(x1, y1)
    if x.ndim == 3:
        for i in range(x.shape[-1]):
            x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)

    return x


def blur(x, k):
    """
    x: image, NxcxHxW
    k: kernel, Nx1xhxw
    """
    n, c = x.shape[:2]
    p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
    x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode="replicate")
    k = k.repeat(1, c, 1, 1)
    k = k.view(-1, 1, k.shape[2], k.shape[3])
    x = x.view(1, -1, x.shape[2], x.shape[3])
    x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
    x = x.view(n, c, x.shape[2], x.shape[3])

    return x


def gen_kernel(
    k_size=np.array([15, 15]),
    scale_factor=np.array([4, 4]),
    min_var=0.6,
    max_var=10.0,
    noise_level=0,
):
    """
    # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
    # Kai Zhang
    # min_var = 0.175 * sf  # variance of the gaussian kernel will be sampled between min_var and max_var
    # max_var = 2.5 * sf
    """
    # Set random eigenvalues (lambdas) and angle (theta) for the covariance matrix
    lambda_1 = min_var + np.random.rand() * (max_var - min_var)
    lambda_2 = min_var + np.random.rand() * (max_var - min_var)
    theta = np.random.rand() * np.pi  # random theta
    noise = -noise_level + np.random.rand(*k_size) * noise_level * 2

    # Set the covariance matrix using lambdas and theta
    LAMBDA = np.diag([lambda_1, lambda_2])
    Q = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
    SIGMA = Q @ LAMBDA @ Q.T
    INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]

    # Set expectation position (shifting kernel for aligned image)
    MU = k_size // 2 - 0.5 * (scale_factor - 1)  # - 0.5 * (scale_factor - k_size % 2)
    MU = MU[None, None, :, None]

    # Create meshgrid for Gaussian
    [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
    Z = np.stack([X, Y], 2)[:, :, :, None]

    # Calculate the Gaussian for every pixel of the kernel
    ZZ = Z - MU
    ZZ_t = ZZ.transpose(0, 1, 3, 2)
    raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)

    # shift the kernel so it will be centered
    # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)

    # Normalize the kernel and return
    # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
    kernel = raw_kernel / np.sum(raw_kernel)
    return kernel


def fspecial_gaussian(hsize, sigma):
    hsize = [hsize, hsize]
    siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
    std = sigma
    [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
    arg = -(x * x + y * y) / (2 * std * std)
    h = np.exp(arg)
    # NOTE: scipy.finfo was a re-export of np.finfo and is gone from modern SciPy;
    # np.finfo(float).eps is the equivalent.
    h[h < scipy.finfo(float).eps * h.max()] = 0
    sumh = h.sum()
    if sumh != 0:
        h = h / sumh
    return h


def fspecial_laplacian(alpha):
    alpha = max([0, min([alpha, 1])])
    h1 = alpha / (alpha + 1)
    h2 = (1 - alpha) / (alpha + 1)
    h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
    h = np.array(h)
    return h


def fspecial(filter_type, *args, **kwargs):
    """
    python code from:
    https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
    """
    if filter_type == "gaussian":
        return fspecial_gaussian(*args, **kwargs)
    if filter_type == "laplacian":
        return fspecial_laplacian(*args, **kwargs)


"""
# --------------------------------------------
# degradation models
# --------------------------------------------
"""


def bicubic_degradation(x, sf=3):
    """
    Args:
        x: HxWxC image, [0, 1]
        sf: down-scale factor
    Return:
        bicubically downsampled LR image
    """
    x = util.imresize_np(x, scale=1 / sf)
    return x


def srmd_degradation(x, k, sf=3):
    """blur + bicubic downsampling
    Args:
        x: HxWxC image, [0, 1]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    Reference:
        @inproceedings{zhang2018learning,
          title={Learning a single convolutional super-resolution network for multiple degradations},
          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
          pages={3262--3271},
          year={2018}
        }
    """
    # NOTE: ndimage.filters is a deprecated alias; modern SciPy exposes ndimage.convolve directly.
    x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode="wrap")  # 'nearest' | 'mirror'
    x = bicubic_degradation(x, sf=sf)
    return x


def dpsr_degradation(x, k, sf=3):
    """bicubic downsampling + blur
    Args:
        x: HxWxC image, [0, 1]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    Reference:
        @inproceedings{zhang2019deep,
          title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
          pages={1671--1681},
          year={2019}
        }
    """
    x = bicubic_degradation(x, sf=sf)
    x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode="wrap")
    return x


def classical_degradation(x, k, sf=3):
    """blur + downsampling
    Args:
        x: HxWxC image, [0, 1]/[0, 255]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    """
    x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode="wrap")
    # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
    st = 0
    return x[st::sf, st::sf, ...]


def add_sharpening(img, weight=0.5, radius=50, threshold=10):
    """USM sharpening, borrowed from Real-ESRGAN.
    Input image: I; Blurry image: B.
    1. K = I + weight * (I - B)
    2. Mask = 1 if abs(I - B) > threshold, else: 0
    3. Blur the mask.
    4. Out = Mask * K + (1 - Mask) * I
    Args:
        img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
        weight (float): Sharp weight. Default: 0.5.
        radius (float): Kernel size of the Gaussian blur. Default: 50.
        threshold (int): Mask threshold, compared against the residual scaled to [0, 255].
    """
    if radius % 2 == 0:
        radius += 1
    blur = cv2.GaussianBlur(img, (radius, radius), 0)
    residual = img - blur
    mask = np.abs(residual) * 255 > threshold
    mask = mask.astype("float32")
    soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)

    K = img + weight * residual
    K = np.clip(K, 0, 1)
    return soft_mask * K + (1 - soft_mask) * img


def add_blur(img, sf=4):
    wd2 = 4.0 + sf
    wd = 2.0 + 0.2 * sf

    wd2 = wd2 / 4
    wd = wd / 4

    if random.random() < 0.5:
        l1 = wd2 * random.random()
        l2 = wd2 * random.random()
        k = anisotropic_Gaussian(
            ksize=random.randint(2, 11) + 3,
            theta=random.random() * np.pi,
            l1=l1,
            l2=l2,
        )
    else:
        k = fspecial("gaussian", random.randint(2, 4) + 3, wd * random.random())
    img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode="mirror")

    return img
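
Compared with add_blur in the bsrgan module above, this light variant quarters the sampled kernel widths and draws smaller kernel sizes; informally:

    # bsrgan.add_blur:        ksize = 2 * randint(2, 11) + 3  (odd, 7..25),  widths up to 4 + sf
    # bsrgan_light.add_blur:  ksize = randint(2, 11) + 3      (5..14),       widths up to (4 + sf) / 4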


def add_resize(img, sf=4):
    rnum = np.random.rand()
    if rnum > 0.8:  # up
        sf1 = random.uniform(1, 2)
    elif rnum < 0.7:  # down
        sf1 = random.uniform(0.5 / sf, 1)
    else:
        sf1 = 1.0
    img = cv2.resize(
        img,
        (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])),
        interpolation=random.choice([1, 2, 3]),
    )
    img = np.clip(img, 0.0, 1.0)

    return img


# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
#     noise_level = random.randint(noise_level1, noise_level2)
#     rnum = np.random.rand()
#     if rnum > 0.6:  # add color Gaussian noise
#         img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
#     elif rnum < 0.4:  # add grayscale Gaussian noise
#         img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
#     else:  # add noise
#         L = noise_level2 / 255.
#         D = np.diag(np.random.rand(3))
#         U = orth(np.random.rand(3, 3))
#         conv = np.dot(np.dot(np.transpose(U), D), U)
#         img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
#     img = np.clip(img, 0.0, 1.0)
#     return img


def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
    noise_level = random.randint(noise_level1, noise_level2)
    rnum = np.random.rand()
    if rnum > 0.6:  # add color Gaussian noise
        img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
    elif rnum < 0.4:  # add grayscale Gaussian noise
        img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
    else:  # add channel-correlated color noise
        L = noise_level2 / 255.0
        D = np.diag(np.random.rand(3))
        U = orth(np.random.rand(3, 3))
        conv = np.dot(np.dot(np.transpose(U), D), U)
        img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32)
    img = np.clip(img, 0.0, 1.0)
    return img


def add_speckle_noise(img, noise_level1=2, noise_level2=25):
    noise_level = random.randint(noise_level1, noise_level2)
    img = np.clip(img, 0.0, 1.0)
    rnum = random.random()
    if rnum > 0.6:
        img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
    elif rnum < 0.4:
        img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
    else:
        L = noise_level2 / 255.0
        D = np.diag(np.random.rand(3))
        U = orth(np.random.rand(3, 3))
        conv = np.dot(np.dot(np.transpose(U), D), U)
        img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32)
    img = np.clip(img, 0.0, 1.0)
    return img


def add_Poisson_noise(img):
    img = np.clip((img * 255.0).round(), 0, 255) / 255.0
    vals = 10 ** (2 * random.random() + 2.0)  # exponent uniform in [2, 4]
    if random.random() < 0.5:
        img = np.random.poisson(img * vals).astype(np.float32) / vals
    else:
        img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
        img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.0
        noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
        img += noise_gray[:, :, np.newaxis]
    img = np.clip(img, 0.0, 1.0)
    return img


def add_JPEG_noise(img):
    quality_factor = random.randint(80, 95)
    img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
    result, encimg = cv2.imencode(".jpg", img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
    img = cv2.imdecode(encimg, 1)
    img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
    return img


def random_crop(lq, hq, sf=4, lq_patchsize=64):
    h, w = lq.shape[:2]
    rnd_h = random.randint(0, h - lq_patchsize)
    rnd_w = random.randint(0, w - lq_patchsize)
    lq = lq[rnd_h : rnd_h + lq_patchsize, rnd_w : rnd_w + lq_patchsize, :]

    rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
    hq = hq[
        rnd_h_H : rnd_h_H + lq_patchsize * sf,
        rnd_w_H : rnd_w_H + lq_patchsize * sf,
        :,
    ]
    return lq, hq


def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
    """
    This is the degradation model of BSRGAN from the paper
    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
    ----------
    img: HxWxC, [0, 1]; its size should be larger than (lq_patchsize x sf) x (lq_patchsize x sf)
    sf: scale factor
    isp_model: camera ISP model
    Returns
    -------
    img: low-quality patch, size: lq_patchsize x lq_patchsize x C, range: [0, 1]
    hq: corresponding high-quality patch, size: (lq_patchsize x sf) x (lq_patchsize x sf) x C, range: [0, 1]
    """
    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
    sf_ori = sf

    h1, w1 = img.shape[:2]
    img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...]  # mod crop
    h, w = img.shape[:2]

    if h < lq_patchsize * sf or w < lq_patchsize * sf:
        raise ValueError(f"img size ({h1}X{w1}) is too small!")

    hq = img.copy()

    if sf == 4 and random.random() < scale2_prob:  # downsample1
        if np.random.rand() < 0.5:
            img = cv2.resize(
                img,
                (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
                interpolation=random.choice([1, 2, 3]),
            )
        else:
            img = util.imresize_np(img, 1 / 2, True)
        img = np.clip(img, 0.0, 1.0)
        sf = 2

    shuffle_order = random.sample(range(7), 7)
    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
    if idx1 > idx2:  # keep downsample3 last
        shuffle_order[idx1], shuffle_order[idx2] = (
            shuffle_order[idx2],
            shuffle_order[idx1],
        )

    for i in shuffle_order:
        if i == 0:
            img = add_blur(img, sf=sf)

        elif i == 1:
            img = add_blur(img, sf=sf)

        elif i == 2:
            a, b = img.shape[1], img.shape[0]
            # downsample2
            if random.random() < 0.75:
                sf1 = random.uniform(1, 2 * sf)
                img = cv2.resize(
                    img,
                    (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
                    interpolation=random.choice([1, 2, 3]),
                )
            else:
                k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf))
                k_shifted = shift_pixel(k, sf)
                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
                img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode="mirror")
                img = img[0::sf, 0::sf, ...]  # nearest downsampling
            img = np.clip(img, 0.0, 1.0)

        elif i == 3:
            # downsample3
            img = cv2.resize(
                img,
                (int(1 / sf * a), int(1 / sf * b)),
                interpolation=random.choice([1, 2, 3]),
            )
            img = np.clip(img, 0.0, 1.0)

        elif i == 4:
            # add Gaussian noise
            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8)

        elif i == 5:
            # add JPEG noise
            if random.random() < jpeg_prob:
                img = add_JPEG_noise(img)

        elif i == 6:
            # add processed camera sensor noise
            if random.random() < isp_prob and isp_model is not None:
                with torch.no_grad():
                    img, hq = isp_model.forward(img.copy(), hq)

    # add final JPEG compression noise
    img = add_JPEG_noise(img)

    # random crop
    img, hq = random_crop(img, hq, sf_ori, lq_patchsize)

    return img, hq


# TODO: no isp_model?
def degradation_bsrgan_variant(image, sf=4, isp_model=None):
    """
    This is the degradation model of BSRGAN from the paper
    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
    ----------
    sf: scale factor
    isp_model: camera ISP model
    Returns
    -------
    example: dict with key "image" holding the degraded low-quality image (uint8)
    """
    image = util.uint2single(image)
    jpeg_prob, scale2_prob = 0.9, 0.25
    # isp_prob = 0.25  # uncomment with the `if i == 6` block below
    # sf_ori = sf  # uncomment with the `if i == 6` block below

    h1, w1 = image.shape[:2]
    image = image.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...]  # mod crop
    h, w = image.shape[:2]

    # hq = image.copy()  # uncomment with the `if i == 6` block below

    if sf == 4 and random.random() < scale2_prob:  # downsample1
        if np.random.rand() < 0.5:
            image = cv2.resize(
                image,
                (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
                interpolation=random.choice([1, 2, 3]),
            )
        else:
            image = util.imresize_np(image, 1 / 2, True)
        image = np.clip(image, 0.0, 1.0)
        sf = 2

    shuffle_order = random.sample(range(7), 7)
    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
    if idx1 > idx2:  # keep downsample3 last
        shuffle_order[idx1], shuffle_order[idx2] = (
            shuffle_order[idx2],
            shuffle_order[idx1],
        )

    for i in shuffle_order:
        if i == 0:
            image = add_blur(image, sf=sf)

        # elif i == 1:
        #     image = add_blur(image, sf=sf)

        if i == 0:  # leftover no-op branch kept from upstream
            pass

        elif i == 2:
            a, b = image.shape[1], image.shape[0]
            # downsample2
            if random.random() < 0.8:
                sf1 = random.uniform(1, 2 * sf)
                image = cv2.resize(
                    image,
                    (
                        int(1 / sf1 * image.shape[1]),
                        int(1 / sf1 * image.shape[0]),
                    ),
                    interpolation=random.choice([1, 2, 3]),
                )
            else:
                k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf))
                k_shifted = shift_pixel(k, sf)
                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
                image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode="mirror")
                image = image[0::sf, 0::sf, ...]  # nearest downsampling

            image = np.clip(image, 0.0, 1.0)

        elif i == 3:
            # downsample3
            image = cv2.resize(
                image,
                (int(1 / sf * a), int(1 / sf * b)),
                interpolation=random.choice([1, 2, 3]),
            )
            image = np.clip(image, 0.0, 1.0)

        elif i == 4:
            # add Gaussian noise
            image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2)

        elif i == 5:
            # add JPEG noise
            if random.random() < jpeg_prob:
                image = add_JPEG_noise(image)

        # elif i == 6:
        #     # add processed camera sensor noise
        #     if random.random() < isp_prob and isp_model is not None:
        #         with torch.no_grad():
        #             img, hq = isp_model.forward(img.copy(), hq)

    # add final JPEG compression noise
    image = add_JPEG_noise(image)
    image = util.single2uint(image)
    example = {"image": image}
    return example


if __name__ == "__main__":
    print("hey")
    img = util.imread_uint("utils/test.png", 3)
    img = img[:448, :448]
    h = img.shape[0] // 4
    print("resizing to", h)
    sf = 4
    deg_fn = partial(degradation_bsrgan_variant, sf=sf)
    for i in range(20):
        print(i)
        img_hq = img
        img_lq = deg_fn(img)["image"]
        img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)
        print(img_lq)
        img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)[
            "image"
        ]
        print(img_lq.shape)
        print("bicubic", img_lq_bicubic.shape)
        print(img_hq.shape)
        lq_nearest = cv2.resize(
            util.single2uint(img_lq),
            (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
            interpolation=0,
        )
        lq_bicubic_nearest = cv2.resize(
            util.single2uint(img_lq_bicubic),
            (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
            interpolation=0,
        )
        img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
        util.imsave(img_concat, str(i) + ".png")
Binary file not shown.
Before Width: | Height: | Size: 431 KiB |
@ -1,968 +0,0 @@
import math
import os
import random
from datetime import datetime

import cv2
import numpy as np
import torch
from torchvision.utils import make_grid

import invokeai.backend.util.logging as logger

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"


"""
# --------------------------------------------
# Kai Zhang (github: https://github.com/cszn)
# 03/Mar/2019
# --------------------------------------------
# https://github.com/twhui/SRGAN-pyTorch
# https://github.com/xinntao/BasicSR
# --------------------------------------------
"""


IMG_EXTENSIONS = [
    ".jpg",
    ".JPG",
    ".jpeg",
    ".JPEG",
    ".png",
    ".PNG",
    ".ppm",
    ".PPM",
    ".bmp",
    ".BMP",
    ".tif",
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def get_timestamp():
    return datetime.now().strftime("%y%m%d-%H%M%S")


def imshow(x, title=None, cbar=False, figsize=None):
    import matplotlib.pyplot as plt

    plt.figure(figsize=figsize)
    plt.imshow(np.squeeze(x), interpolation="nearest", cmap="gray")
    if title:
        plt.title(title)
    if cbar:
        plt.colorbar()
    plt.show()


def surf(Z, cmap="rainbow", figsize=None):
    import matplotlib.pyplot as plt

    plt.figure(figsize=figsize)
    ax3 = plt.axes(projection="3d")

    w, h = Z.shape[:2]
    xx = np.arange(0, w, 1)
    yy = np.arange(0, h, 1)
    X, Y = np.meshgrid(xx, yy)
    ax3.plot_surface(X, Y, Z, cmap=cmap)
    # ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap)
    plt.show()


"""
# --------------------------------------------
# get image paths
# --------------------------------------------
"""


def get_image_paths(dataroot):
    paths = None  # return None if dataroot is None
    if dataroot is not None:
        paths = sorted(_get_paths_from_images(dataroot))
    return paths


def _get_paths_from_images(path):
    assert os.path.isdir(path), "{:s} is not a valid directory".format(path)
    images = []
    for dirpath, _, fnames in sorted(os.walk(path, followlinks=True)):
        for fname in sorted(fnames):
            if is_image_file(fname):
                img_path = os.path.join(dirpath, fname)
                images.append(img_path)
    assert images, "{:s} has no valid image file".format(path)
    return images


"""
# --------------------------------------------
# split large images into small images
# --------------------------------------------
"""


def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
    w, h = img.shape[:2]
    patches = []
    if w > p_max and h > p_max:
        w1 = list(np.arange(0, w - p_size, p_size - p_overlap, dtype=int))
        h1 = list(np.arange(0, h - p_size, p_size - p_overlap, dtype=int))
        w1.append(w - p_size)
        h1.append(h - p_size)
        # print(w1)
        # print(h1)
        for i in w1:
            for j in h1:
                patches.append(img[i : i + p_size, j : j + p_size, :])
    else:
        patches.append(img)

    return patches
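

# Example (minimal sketch): split a synthetic large image into overlapping
# patches with the helper above; the array and the sizes are arbitrary.
if __name__ == "__main__":
    _demo = (np.random.rand(1200, 900, 3) * 255).astype(np.uint8)
    _patches = patches_from_image(_demo, p_size=512, p_overlap=64, p_max=800)
    print(len(_patches), _patches[0].shape)  # 6 (512, 512, 3)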


def imssave(imgs, img_path):
    """
    imgs: list, N images of size WxHxC
    """
    img_name, ext = os.path.splitext(os.path.basename(img_path))

    for i, img in enumerate(imgs):
        if img.ndim == 3:
            img = img[:, :, [2, 1, 0]]
        new_path = os.path.join(
            os.path.dirname(img_path),
            img_name + "_s{:04d}".format(i) + ".png",
        )
        cv2.imwrite(new_path, img)


def split_imageset(
    original_dataroot,
    target_dataroot,
    n_channels=3,
    p_size=800,
    p_overlap=96,
    p_max=1000,
):
    """
    split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size),
    and save them into target_dataroot; only images larger than (p_max)x(p_max)
    will be split.
    Args:
        original_dataroot:
        target_dataroot:
        p_size: size of small images
        p_overlap: overlap between adjacent patches; the patch size used in training is a good choice
        p_max: images smaller than (p_max)x(p_max) remain unchanged.
    """
    paths = get_image_paths(original_dataroot)
    for img_path in paths:
        # img_name, ext = os.path.splitext(os.path.basename(img_path))
        img = imread_uint(img_path, n_channels=n_channels)
        patches = patches_from_image(img, p_size, p_overlap, p_max)
        imssave(patches, os.path.join(target_dataroot, os.path.basename(img_path)))
        # if original_dataroot == target_dataroot:
        #     del img_path


"""
# --------------------------------------------
# makedir
# --------------------------------------------
"""


def mkdir(path):
    if not os.path.exists(path):
        os.makedirs(path)


def mkdirs(paths):
    if isinstance(paths, str):
        mkdir(paths)
    else:
        for path in paths:
            mkdir(path)


def mkdir_and_rename(path):
    if os.path.exists(path):
        new_name = path + "_archived_" + get_timestamp()
        logger.error("Path already exists. Renaming it to [{:s}]".format(new_name))
        os.replace(path, new_name)
    os.makedirs(path)


"""
# --------------------------------------------
# read image from path
# opencv is fast, but reads BGR numpy images
# --------------------------------------------
"""


# --------------------------------------------
# get uint8 image of size HxWxn_channels (RGB)
# --------------------------------------------
def imread_uint(path, n_channels=3):
    # input: path
    # output: HxWx3 (RGB or GGG), or HxWx1 (G)
    if n_channels == 1:
        img = cv2.imread(path, 0)  # cv2.IMREAD_GRAYSCALE
        img = np.expand_dims(img, axis=2)  # HxWx1
    elif n_channels == 3:
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # BGR or G
        if img.ndim == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)  # GGG
        else:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # RGB
    return img


# --------------------------------------------
# matlab's imwrite
# --------------------------------------------
def imsave(img, img_path):
    img = np.squeeze(img)
    if img.ndim == 3:
        img = img[:, :, [2, 1, 0]]
    cv2.imwrite(img_path, img)


def imwrite(img, img_path):
    img = np.squeeze(img)
    if img.ndim == 3:
        img = img[:, :, [2, 1, 0]]
    cv2.imwrite(img_path, img)


# --------------------------------------------
# get single image of size HxWxn_channels (BGR)
# --------------------------------------------
def read_img(path):
    # read image by cv2
    # return: Numpy float32, HWC, BGR, [0,1]
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # cv2.IMREAD_GRAYSCALE
    img = img.astype(np.float32) / 255.0
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    # some images have 4 channels
    if img.shape[2] > 3:
        img = img[:, :, :3]
    return img


"""
# --------------------------------------------
# image format conversion
# --------------------------------------------
# numpy(single) <---> numpy(uint)
# numpy(single) <---> tensor
# numpy(uint)   <---> tensor
# --------------------------------------------
"""


# --------------------------------------------
# numpy(single) [0, 1] <---> numpy(uint)
# --------------------------------------------


def uint2single(img):
    return np.float32(img / 255.0)


def single2uint(img):
    return np.uint8((img.clip(0, 1) * 255.0).round())


def uint162single(img):
    return np.float32(img / 65535.0)


def single2uint16(img):
    return np.uint16((img.clip(0, 1) * 65535.0).round())
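

# Example (minimal sketch): the uint8 <-> single conversions above are exact
# inverses up to rounding.
if __name__ == "__main__":
    _u = np.array([[0, 128, 255]], dtype=np.uint8)
    _s = uint2single(_u)  # float32 in [0, 1]
    assert (single2uint(_s) == _u).all()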


# --------------------------------------------
# numpy(uint) (HxWxC or HxW) <---> tensor
# --------------------------------------------


# convert uint to 4-dimensional torch tensor
def uint2tensor4(img):
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.0).unsqueeze(0)


# convert uint to 3-dimensional torch tensor
def uint2tensor3(img):
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.0)


# convert 2/3/4-dimensional torch tensor to uint
def tensor2uint(img):
    img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
    if img.ndim == 3:
        img = np.transpose(img, (1, 2, 0))
    return np.uint8((img * 255.0).round())


# --------------------------------------------
# numpy(single) (HxWxC) <---> tensor
# --------------------------------------------


# convert single (HxWxC) to 3-dimensional torch tensor
def single2tensor3(img):
    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float()


# convert single (HxWxC) to 4-dimensional torch tensor
def single2tensor4(img):
    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0)


# convert torch tensor to single
def tensor2single(img):
    img = img.data.squeeze().float().cpu().numpy()
    if img.ndim == 3:
        img = np.transpose(img, (1, 2, 0))

    return img


# convert torch tensor to single
def tensor2single3(img):
    img = img.data.squeeze().float().cpu().numpy()
    if img.ndim == 3:
        img = np.transpose(img, (1, 2, 0))
    elif img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    return img


def single2tensor5(img):
    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0)


def single32tensor5(img):
    return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0)


def single42tensor4(img):
    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float()


# from skimage.io import imread, imsave
def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
    """
    Converts a torch Tensor into an image Numpy array of BGR channel order
    Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
    Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
    """
    tensor = tensor.squeeze().float().cpu().clamp_(*min_max)  # squeeze first, then clamp
    tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])  # to range [0,1]
    n_dim = tensor.dim()
    if n_dim == 4:
        n_img = len(tensor)
        img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
    elif n_dim == 3:
        img_np = tensor.numpy()
        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
    elif n_dim == 2:
        img_np = tensor.numpy()
    else:
        raise TypeError("Only 4D, 3D and 2D tensors are supported, but received a tensor of dimension: {:d}".format(n_dim))
    if out_type == np.uint8:
        img_np = (img_np * 255.0).round()
        # Important. Unlike matlab, numpy.uint8() WILL NOT round by default.
    return img_np.astype(out_type)
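

# Example (minimal sketch): a CHW float tensor in [0, 1] becomes an HWC uint8
# BGR array, ready for cv2.imwrite.
if __name__ == "__main__":
    _t = torch.rand(3, 16, 16)
    print(tensor2img(_t).shape)  # (16, 16, 3)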


"""
# --------------------------------------------
# Augmentation, flip and/or rotate
# --------------------------------------------
# The following two are enough.
# (1) augment_img: numpy image of WxHxC or WxH
# (2) augment_img_tensor4: tensor image 1xCxWxH
# --------------------------------------------
"""


def augment_img(img, mode=0):
    """Kai Zhang (github: https://github.com/cszn)"""
    if mode == 0:
        return img
    elif mode == 1:
        return np.flipud(np.rot90(img))
    elif mode == 2:
        return np.flipud(img)
    elif mode == 3:
        return np.rot90(img, k=3)
    elif mode == 4:
        return np.flipud(np.rot90(img, k=2))
    elif mode == 5:
        return np.rot90(img)
    elif mode == 6:
        return np.rot90(img, k=2)
    elif mode == 7:
        return np.flipud(np.rot90(img, k=3))
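

# Example (minimal sketch): the eight modes above enumerate the dihedral group
# of the square (identity, three rotations, and the flipped version of each),
# so every mode maps an HxW image to either HxW or WxH.
if __name__ == "__main__":
    _img = np.arange(6).reshape(2, 3)
    for _m in range(8):
        print(_m, augment_img(_img, mode=_m).shape)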


def augment_img_tensor4(img, mode=0):
    """Kai Zhang (github: https://github.com/cszn)"""
    if mode == 0:
        return img
    elif mode == 1:
        return img.rot90(1, [2, 3]).flip([2])
    elif mode == 2:
        return img.flip([2])
    elif mode == 3:
        return img.rot90(3, [2, 3])
    elif mode == 4:
        return img.rot90(2, [2, 3]).flip([2])
    elif mode == 5:
        return img.rot90(1, [2, 3])
    elif mode == 6:
        return img.rot90(2, [2, 3])
    elif mode == 7:
        return img.rot90(3, [2, 3]).flip([2])


def augment_img_tensor(img, mode=0):
    """Kai Zhang (github: https://github.com/cszn)"""
    img_size = img.size()
    img_np = img.data.cpu().numpy()
    if len(img_size) == 3:
        img_np = np.transpose(img_np, (1, 2, 0))
    elif len(img_size) == 4:
        img_np = np.transpose(img_np, (2, 3, 1, 0))
    img_np = augment_img(img_np, mode=mode)
    img_tensor = torch.from_numpy(np.ascontiguousarray(img_np))
    if len(img_size) == 3:
        img_tensor = img_tensor.permute(2, 0, 1)
    elif len(img_size) == 4:
        img_tensor = img_tensor.permute(3, 2, 0, 1)

    return img_tensor.type_as(img)


def augment_img_np3(img, mode=0):
    if mode == 0:
        return img
    elif mode == 1:
        return img.transpose(1, 0, 2)
    elif mode == 2:
        return img[::-1, :, :]
    elif mode == 3:
        img = img[::-1, :, :]
        img = img.transpose(1, 0, 2)
        return img
    elif mode == 4:
        return img[:, ::-1, :]
    elif mode == 5:
        img = img[:, ::-1, :]
        img = img.transpose(1, 0, 2)
        return img
    elif mode == 6:
        img = img[:, ::-1, :]
        img = img[::-1, :, :]
        return img
    elif mode == 7:
        img = img[:, ::-1, :]
        img = img[::-1, :, :]
        img = img.transpose(1, 0, 2)
        return img


def augment_imgs(img_list, hflip=True, rot=True):
    # horizontal flip OR rotate
    hflip = hflip and random.random() < 0.5
    vflip = rot and random.random() < 0.5
    rot90 = rot and random.random() < 0.5

    def _augment(img):
        if hflip:
            img = img[:, ::-1, :]
        if vflip:
            img = img[::-1, :, :]
        if rot90:
            img = img.transpose(1, 0, 2)
        return img

    return [_augment(img) for img in img_list]


"""
# --------------------------------------------
# modcrop and shave
# --------------------------------------------
"""


def modcrop(img_in, scale):
    # img_in: Numpy, HWC or HW
    img = np.copy(img_in)
    if img.ndim == 2:
        H, W = img.shape
        H_r, W_r = H % scale, W % scale
        img = img[: H - H_r, : W - W_r]
    elif img.ndim == 3:
        H, W, C = img.shape
        H_r, W_r = H % scale, W % scale
        img = img[: H - H_r, : W - W_r, :]
    else:
        raise ValueError("Wrong img ndim: [{:d}].".format(img.ndim))
    return img


def shave(img_in, border=0):
    # img_in: Numpy, HWC or HW
    img = np.copy(img_in)
    h, w = img.shape[:2]
    img = img[border : h - border, border : w - border]
    return img


"""
# --------------------------------------------
# image processing on numpy images
# channel_convert(in_c, tar_type, img_list):
# rgb2ycbcr(img, only_y=True):
# bgr2ycbcr(img, only_y=True):
# ycbcr2rgb(img):
# --------------------------------------------
"""


def rgb2ycbcr(img, only_y=True):
    """same as matlab rgb2ycbcr
    only_y: only return Y channel
    Input:
        uint8, [0, 255]
        float, [0, 1]
    """
    in_img_type = img.dtype
    img = img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.0
    # convert
    if only_y:
        rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
    else:
        rlt = np.matmul(
            img,
            [
                [65.481, -37.797, 112.0],
                [128.553, -74.203, -93.786],
                [24.966, 112.0, -18.214],
            ],
        ) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.0
    return rlt.astype(in_img_type)


def ycbcr2rgb(img):
    """same as matlab ycbcr2rgb
    Input:
        uint8, [0, 255]
        float, [0, 1]
    """
    in_img_type = img.dtype
    img = img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.0
    # convert
    rlt = np.matmul(
        img,
        [
            [0.00456621, 0.00456621, 0.00456621],
            [0, -0.00153632, 0.00791071],
            [0.00625893, -0.00318811, 0],
        ],
    ) * 255.0 + [-222.921, 135.576, -276.836]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.0
    return rlt.astype(in_img_type)


def bgr2ycbcr(img, only_y=True):
    """bgr version of rgb2ycbcr
    only_y: only return Y channel
    Input:
        uint8, [0, 255]
        float, [0, 1]
    """
    in_img_type = img.dtype
    img = img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.0
    # convert
    if only_y:
        rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
    else:
        rlt = np.matmul(
            img,
            [
                [24.966, 112.0, -18.214],
                [128.553, -74.203, -93.786],
                [65.481, -37.797, 112.0],
            ],
        ) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.0
    return rlt.astype(in_img_type)


def channel_convert(in_c, tar_type, img_list):
    # conversion among BGR, gray and y
    if in_c == 3 and tar_type == "gray":  # BGR to gray
        gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
        return [np.expand_dims(img, axis=2) for img in gray_list]
    elif in_c == 3 and tar_type == "y":  # BGR to y
        y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
        return [np.expand_dims(img, axis=2) for img in y_list]
    elif in_c == 1 and tar_type == "RGB":  # gray/y to BGR
        return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
    else:
        return img_list
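

# Example (minimal sketch): SR benchmarks commonly evaluate on the luma
# channel only; the BGR -> Y conversion above does exactly that.
if __name__ == "__main__":
    _bgr = (np.random.rand(16, 16, 3) * 255).astype(np.uint8)
    _y = bgr2ycbcr(_bgr, only_y=True)
    print(_y.shape, _y.dtype)  # (16, 16) uint8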


"""
# --------------------------------------------
# metric, PSNR and SSIM
# --------------------------------------------
"""


# --------------------------------------------
# PSNR
# --------------------------------------------
def calculate_psnr(img1, img2, border=0):
    # img1 and img2 have range [0, 255]
    # img1 = img1.squeeze()
    # img2 = img2.squeeze()
    if not img1.shape == img2.shape:
        raise ValueError("Input images must have the same dimensions.")
    h, w = img1.shape[:2]
    img1 = img1[border : h - border, border : w - border]
    img2 = img2[border : h - border, border : w - border]

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0:
        return float("inf")
    return 20 * math.log10(255.0 / math.sqrt(mse))
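

# Worked example (sketch): for uint8 images PSNR = 20*log10(255/sqrt(MSE)),
# so a uniform error of one gray level gives MSE = 1 and PSNR ~ 48.13 dB.
if __name__ == "__main__":
    _a = np.full((8, 8), 100, dtype=np.uint8)
    _b = np.full((8, 8), 101, dtype=np.uint8)
    print(calculate_psnr(_a, _b))  # ~48.13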


# --------------------------------------------
# SSIM
# --------------------------------------------
def calculate_ssim(img1, img2, border=0):
    """calculate SSIM
    the same outputs as MATLAB's
    img1, img2: [0, 255]
    """
    # img1 = img1.squeeze()
    # img2 = img2.squeeze()
    if not img1.shape == img2.shape:
        raise ValueError("Input images must have the same dimensions.")
    h, w = img1.shape[:2]
    img1 = img1[border : h - border, border : w - border]
    img2 = img2[border : h - border, border : w - border]

    if img1.ndim == 2:
        return ssim(img1, img2)
    elif img1.ndim == 3:
        if img1.shape[2] == 3:
            ssims = []
            for i in range(3):
                ssims.append(ssim(img1[:, :, i], img2[:, :, i]))
            return np.array(ssims).mean()
        elif img1.shape[2] == 1:
            return ssim(np.squeeze(img1), np.squeeze(img2))
    else:
        raise ValueError("Wrong input image dimensions.")


def ssim(img1, img2):
    C1 = (0.01 * 255) ** 2
    C2 = (0.03 * 255) ** 2

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    kernel = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(kernel, kernel.transpose())

    mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]  # valid
    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
    mu1_sq = mu1**2
    mu2_sq = mu2**2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
    sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
    sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
    return ssim_map.mean()
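

# Example (minimal sketch): identical inputs score SSIM = 1.0; inputs must be
# at least 11x11 for the 11x11 Gaussian window and the 5-pixel crop above.
if __name__ == "__main__":
    _x = (np.random.rand(32, 32) * 255).astype(np.uint8)
    print(calculate_ssim(_x, _x))  # 1.0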


"""
# --------------------------------------------
# matlab's bicubic imresize (numpy and torch) [0, 1]
# --------------------------------------------
"""


# matlab 'imresize' function, now only supports 'bicubic'
def cubic(x):
    absx = torch.abs(x)
    absx2 = absx**2
    absx3 = absx**3
    return (1.5 * absx3 - 2.5 * absx2 + 1) * ((absx <= 1).type_as(absx)) + (
        -0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2
    ) * (((absx > 1) * (absx <= 2)).type_as(absx))
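

# For reference, cubic() above is Keys' cubic convolution kernel with a = -0.5
# (the same kernel MATLAB's imresize uses for 'bicubic'):
#   W(x) = 1.5|x|^3 - 2.5|x|^2 + 1           for |x| <= 1
#   W(x) = -0.5|x|^3 + 2.5|x|^2 - 4|x| + 2   for 1 < |x| <= 2
#   W(x) = 0                                 otherwise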


def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
    if (scale < 1) and (antialiasing):
        # Use a modified kernel to simultaneously interpolate and antialias: a larger kernel width
        kernel_width = kernel_width / scale

    # Output-space coordinates
    x = torch.linspace(1, out_length, out_length)

    # Input-space coordinates. Calculate the inverse mapping such that 0.5
    # in output space maps to 0.5 in input space, and 0.5+scale in output
    # space maps to 1.5 in input space.
    u = x / scale + 0.5 * (1 - 1 / scale)

    # What is the left-most pixel that can be involved in the computation?
    left = torch.floor(u - kernel_width / 2)

    # What is the maximum number of pixels that can be involved in the
    # computation? Note: it's OK to use an extra pixel here; if the
    # corresponding weights are all zero, it will be eliminated at the end
    # of this function.
    P = math.ceil(kernel_width) + 2

    # The indices of the input pixels involved in computing the k-th output
    # pixel are in row k of the indices matrix.
    indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(1, P).expand(
        out_length, P
    )

    # The weights used to compute the k-th output pixel are in row k of the
    # weights matrix.
    distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
    # apply cubic kernel
    if (scale < 1) and (antialiasing):
        weights = scale * cubic(distance_to_center * scale)
    else:
        weights = cubic(distance_to_center)
    # Normalize the weights matrix so that each row sums to 1.
    weights_sum = torch.sum(weights, 1).view(out_length, 1)
    weights = weights / weights_sum.expand(out_length, P)

    # If a column in weights is all zero, get rid of it. only consider the first and last column.
    weights_zero_tmp = torch.sum((weights == 0), 0)
    if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 1, P - 2)
        weights = weights.narrow(1, 1, P - 2)
    if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 0, P - 2)
        weights = weights.narrow(1, 0, P - 2)
    weights = weights.contiguous()
    indices = indices.contiguous()
    sym_len_s = -indices.min() + 1
    sym_len_e = indices.max() - in_length
    indices = indices + sym_len_s - 1
    return weights, indices, int(sym_len_s), int(sym_len_e)


# --------------------------------------------
# imresize for tensor image [0, 1]
# --------------------------------------------
def imresize(img, scale, antialiasing=True):
    # Now the scale should be the same for H and W
    # input: img: pytorch tensor, CHW or HW [0,1]
    # output: CHW or HW [0,1] w/o round
    need_squeeze = img.dim() == 2
    if need_squeeze:
        img.unsqueeze_(0)
    in_C, in_H, in_W = img.size()
    out_C, out_H, out_W = (
        in_C,
        math.ceil(in_H * scale),
        math.ceil(in_W * scale),
    )
    kernel_width = 4
    kernel = "cubic"

    # Return the desired dimension order for performing the resize. The
    # strategy is to perform the resize first along the dimension with the
    # smallest scale factor.
    # Now we do not support this.

    # get weights and indices
    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
        in_H, out_H, scale, kernel, kernel_width, antialiasing
    )
    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
        in_W, out_W, scale, kernel, kernel_width, antialiasing
    )
    # process H dimension
    # symmetric copying
    img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
    img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)

    sym_patch = img[:, :sym_len_Hs, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)

    sym_patch = img[:, -sym_len_He:, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)

    out_1 = torch.FloatTensor(in_C, out_H, in_W)
    kernel_width = weights_H.size(1)
    for i in range(out_H):
        idx = int(indices_H[i][0])
        for j in range(out_C):
            out_1[j, i, :] = img_aug[j, idx : idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])

    # process W dimension
    # symmetric copying
    out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
    out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)

    sym_patch = out_1[:, :, :sym_len_Ws]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)

    sym_patch = out_1[:, :, -sym_len_We:]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)

    out_2 = torch.FloatTensor(in_C, out_H, out_W)
    kernel_width = weights_W.size(1)
    for i in range(out_W):
        idx = int(indices_W[i][0])
        for j in range(out_C):
            out_2[j, :, i] = out_1_aug[j, :, idx : idx + kernel_width].mv(weights_W[i])
    if need_squeeze:
        out_2.squeeze_()
    return out_2


# --------------------------------------------
# imresize for numpy image [0, 1]
# --------------------------------------------
def imresize_np(img, scale, antialiasing=True):
    # Now the scale should be the same for H and W
    # input: img: Numpy, HWC or HW [0,1]
    # output: HWC or HW [0,1] w/o round
    img = torch.from_numpy(img)
    need_squeeze = img.dim() == 2
    if need_squeeze:
        img.unsqueeze_(2)

    in_H, in_W, in_C = img.size()
    out_C, out_H, out_W = (
        in_C,
        math.ceil(in_H * scale),
        math.ceil(in_W * scale),
    )
    kernel_width = 4
    kernel = "cubic"

    # Return the desired dimension order for performing the resize. The
    # strategy is to perform the resize first along the dimension with the
    # smallest scale factor.
    # Now we do not support this.

    # get weights and indices
    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
        in_H, out_H, scale, kernel, kernel_width, antialiasing
    )
    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
        in_W, out_W, scale, kernel, kernel_width, antialiasing
    )
    # process H dimension
    # symmetric copying
    img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
    img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)

    sym_patch = img[:sym_len_Hs, :, :]
    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(0, inv_idx)
    img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)

    sym_patch = img[-sym_len_He:, :, :]
    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(0, inv_idx)
    img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)

    out_1 = torch.FloatTensor(out_H, in_W, in_C)
    kernel_width = weights_H.size(1)
    for i in range(out_H):
        idx = int(indices_H[i][0])
        for j in range(out_C):
            out_1[i, :, j] = img_aug[idx : idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])

    # process W dimension
    # symmetric copying
    out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
    out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)

    sym_patch = out_1[:, :sym_len_Ws, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)

    sym_patch = out_1[:, -sym_len_We:, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)

    out_2 = torch.FloatTensor(out_H, out_W, in_C)
    kernel_width = weights_W.size(1)
    for i in range(out_W):
        idx = int(indices_W[i][0])
        for j in range(out_C):
            out_2[:, i, j] = out_1_aug[:, idx : idx + kernel_width, j].mv(weights_W[i])
    if need_squeeze:
        out_2.squeeze_()

    return out_2.numpy()
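

# Example (minimal sketch): MATLAB-style bicubic downscale of a float image in
# [0, 1]; the output size is ceil(scale * input size) per spatial dimension.
if __name__ == "__main__":
    _img = np.random.rand(64, 48, 3).astype(np.float32)
    print(imresize_np(_img, 1 / 4).shape)  # (16, 12, 3)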


if __name__ == "__main__":
    print("---")
    # img = imread_uint('test.bmp', 3)
    # img = uint2single(img)
    # img_bicubic = imresize_np(img, 1/4)
@ -10,7 +10,6 @@ from .devices import ( # noqa: F401
    normalize_device,
    torch_dtype,
)
from .log import write_log # noqa: F401
from .util import ( # noqa: F401
    ask_user,
    download_with_resume,
@ -1,283 +0,0 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ycYWcsEKc6w7"
   },
   "source": [
    "# Stable Diffusion AI Notebook (Release 2.0.0)\n",
    "\n",
    "<img src=\"https://user-images.githubusercontent.com/60411196/186547976-d9de378a-9de8-4201-9c25-c057a9c59bad.jpeg\" alt=\"stable-diffusion-ai\" width=\"170px\"/> <br>\n",
    "#### Instructions:\n",
    "1. Execute each cell in order to mount a Dream bot and create images from text. <br>\n",
    "2. Once cells 1-8 have run correctly, you'll be executing a terminal in cell #9; enter the `python scripts/dream.py` command to run Dream bot.<br>\n",
    "3. After launching Dream bot you'll see `Dream > ` in the terminal. Insert a command, e.g. `Dream > Astronaut floating in a distant galaxy`, or type `-h` for help.\n",
    "4. After completion you'll find your generated images in the path `stable-diffusion/outputs/img-samples/`; you can also show the last generated images in cell #10.\n",
    "5. To quit Dream bot use the `q` command. <br>\n",
    "---\n",
    "<font color=\"red\">Note:</font> It takes some time to load, but after installing all dependencies you can use the bot as long as you want while the Colab instance is up. <br>\n",
    "<font color=\"red\">Requirements:</font> For this notebook to work you need to have [Stable-Diffusion-v-1-4](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original) stored in your Google Drive; it will be needed in cell #7\n",
    "##### For more details visit the Github repository: [invoke-ai/InvokeAI](https://github.com/invoke-ai/InvokeAI)\n",
    "---\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dr32VLxlnouf"
   },
   "source": [
    "## ◢ Installation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "a2Z5Qu_o8VtQ"
   },
   "outputs": [],
   "source": [
    "# @title 1. Check current GPU assigned\n",
    "!nvidia-smi -L\n",
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "vbI9ZsQHzjqF"
   },
   "outputs": [],
   "source": [
    "# @title 2. Download stable-diffusion Repository\n",
    "from os.path import exists\n",
    "\n",
    "!git clone --quiet https://github.com/invoke-ai/InvokeAI.git  # Original repo\n",
    "%cd /content/InvokeAI/\n",
    "!git checkout --quiet tags/v2.0.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "QbXcGXYEFSNB"
   },
   "outputs": [],
   "source": [
    "# @title 3. Install dependencies\n",
    "import gc\n",
    "\n",
    "!wget https://raw.githubusercontent.com/invoke-ai/InvokeAI/development/environments-and-requirements/requirements-base.txt\n",
    "!wget https://raw.githubusercontent.com/invoke-ai/InvokeAI/development/environments-and-requirements/requirements-win-colab-cuda.txt\n",
    "!pip install colab-xterm\n",
    "!pip install -r requirements-win-colab-cuda.txt\n",
    "!pip install clean-fid torchtext\n",
    "!pip install transformers\n",
    "gc.collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "8rSMhgnAttQa"
   },
   "outputs": [],
   "source": [
    "# @title 4. Restart Runtime\n",
    "exit()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "ChIDWxLVHGGJ"
   },
   "outputs": [],
   "source": [
    "# @title 5. Load small ML models required\n",
    "import gc\n",
    "\n",
    "%cd /content/InvokeAI/\n",
    "!python scripts/preload_models.py\n",
    "gc.collect()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "795x1tMoo8b1"
   },
   "source": [
    "## ◢ Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "YEWPV-sF1RDM"
   },
   "outputs": [],
   "source": [
    "# @title 6. Mount Google Drive\n",
    "from google.colab import drive\n",
    "\n",
    "drive.mount(\"/content/drive\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "zRTJeZ461WGu"
   },
   "outputs": [],
   "source": [
    "# @title 7. Drive Path to model\n",
    "# @markdown Path should start with /content/drive/path-to-your-file <br>\n",
    "# @markdown <font color=\"red\">Note:</font> Model should be downloaded from https://huggingface.co <br>\n",
    "# @markdown Latest release: [Stable-Diffusion-v-1-4](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original)\n",
    "from os.path import exists\n",
    "\n",
    "model_path = \"\"  # @param {type:\"string\"}\n",
    "if exists(model_path):\n",
    "    print(\"✅ Valid directory\")\n",
    "else:\n",
    "    print(\"❌ File doesn't exist\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "UY-NNz4I8_aG"
   },
   "outputs": [],
   "source": [
    "# @title 8. Symlink to model\n",
    "\n",
    "from os.path import exists\n",
    "import os\n",
    "\n",
    "# Folder creation if it doesn't exist\n",
    "if exists(\"/content/InvokeAI/models/ldm/stable-diffusion-v1\"):\n",
    "    print(\"❗ Dir stable-diffusion-v1 already exists\")\n",
    "else:\n",
    "    %mkdir /content/InvokeAI/models/ldm/stable-diffusion-v1\n",
    "    print(\"✅ Dir stable-diffusion-v1 created\")\n",
    "\n",
    "# Symbolic link if it doesn't exist\n",
    "if exists(\"/content/InvokeAI/models/ldm/stable-diffusion-v1/model.ckpt\"):\n",
    "    print(\"❗ Symlink already created\")\n",
    "else:\n",
    "    src = model_path\n",
    "    dst = \"/content/InvokeAI/models/ldm/stable-diffusion-v1/model.ckpt\"\n",
    "    os.symlink(src, dst)\n",
    "    print(\"✅ Symbolic link created successfully\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Mc28N0_NrCQH"
   },
   "source": [
    "## ◢ Execution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "ir4hCrMIuUpl"
   },
   "outputs": [],
   "source": [
    "# @title 9. Run Terminal and Execute Dream bot\n",
    "# @markdown <font color=\"blue\">Steps:</font> <br>\n",
    "# @markdown 1. Execute command `python scripts/invoke.py` to run InvokeAI.<br>\n",
    "# @markdown 2. After initialization you'll see the `Dream>` line.<br>\n",
    "# @markdown 3. Example text: `Astronaut floating in a distant galaxy` <br>\n",
    "# @markdown 4. To quit Dream bot use: `q` command.<br>\n",
    "\n",
    "%load_ext colabxterm\n",
    "%xterm\n",
    "gc.collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "qnLohSHmKoGk"
   },
   "outputs": [],
   "source": [
    "# @title 10. Show the last 15 generated images\n",
    "import glob\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.image as mpimg\n",
    "%matplotlib inline\n",
    "\n",
    "images = []\n",
    "for img_path in sorted(glob.glob('/content/InvokeAI/outputs/img-samples/*.png'), reverse=True):\n",
    "    images.append(mpimg.imread(img_path))\n",
    "\n",
    "images = images[:15]\n",
    "\n",
    "plt.figure(figsize=(20,10))\n",
    "\n",
    "columns = 5\n",
    "for i, image in enumerate(images):\n",
    "    ax = plt.subplot(len(images) // columns + 1, columns, i + 1)\n",
    "    ax.axes.xaxis.set_visible(False)\n",
    "    ax.axes.yaxis.set_visible(False)\n",
    "    ax.axis('off')\n",
    "    plt.imshow(image)\n",
    "    gc.collect()\n"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "private_outputs": true,
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "Python 3.9.12 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.9.12"
  },
  "vscode": {
   "interpreter": {
    "hash": "4e870c5c5fe42db7e2c5647ae5af656ff3391bf8c2b729cbf7fa0e16ca8cb5af"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
@ -1,339 +0,0 @@
from torchvision.datasets.utils import download_url
from ldm.util import instantiate_from_config
import torch
import os

# todo ?
from google.colab import files
from IPython.display import Image as ipyimg
import ipywidgets as widgets
from PIL import Image
from einops import rearrange, repeat
import torchvision
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import ismap
import time
from omegaconf import OmegaConf
from ldm.invoke.devices import choose_torch_device


def download_models(mode):
    if mode == "superresolution":
        # this is the small bsr light model
        url_conf = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
        url_ckpt = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"

        path_conf = "logs/diffusion/superresolution_bsr/configs/project.yaml"
        path_ckpt = "logs/diffusion/superresolution_bsr/checkpoints/last.ckpt"

        download_url(url_conf, path_conf)
        download_url(url_ckpt, path_ckpt)

        path_conf = path_conf + "/?dl=1"  # fix it
        path_ckpt = path_ckpt + "/?dl=1"  # fix it
        return path_conf, path_ckpt

    else:
        raise NotImplementedError


def load_model_from_config(config, ckpt):
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    global_step = pl_sd["global_step"]
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    model.cuda()
    model.eval()
    return {"model": model}, global_step


def get_model(mode):
    path_conf, path_ckpt = download_models(mode)
    config = OmegaConf.load(path_conf)
    model, step = load_model_from_config(config, path_ckpt)
    return model


def get_custom_cond(mode):
    dest = "data/example_conditioning"

    if mode == "superresolution":
        uploaded_img = files.upload()
        filename = next(iter(uploaded_img))
        name, filetype = filename.rsplit(".", 1)  # split on the last dot so dots in the name are safe
        os.rename(f"{filename}", f"{dest}/{mode}/custom_{name}.{filetype}")

    elif mode == "text_conditional":
        w = widgets.Text(value="A cake with cream!", disabled=True)
        display(w)  # noqa: F821

        with open(f"{dest}/{mode}/custom_{w.value[:20]}.txt", "w") as f:
            f.write(w.value)

    elif mode == "class_conditional":
        w = widgets.IntSlider(min=0, max=1000)
        display(w)  # noqa: F821
        with open(f"{dest}/{mode}/custom.txt", "w") as f:
            f.write(str(w.value))  # the slider value is an int and must be stringified

    else:
        raise NotImplementedError(f"cond not implemented for mode {mode}")


def get_cond_options(mode):
    path = "data/example_conditioning"
    path = os.path.join(path, mode)
    onlyfiles = [f for f in sorted(os.listdir(path))]
    return path, onlyfiles


def select_cond_path(mode):
    path = "data/example_conditioning"  # todo
    path = os.path.join(path, mode)
    onlyfiles = [f for f in sorted(os.listdir(path))]

    selected = widgets.RadioButtons(options=onlyfiles, description="Select conditioning:", disabled=False)
    display(selected)  # noqa: F821
    selected_path = os.path.join(path, selected.value)
    return selected_path


def get_cond(mode, selected_path):
    example = dict()
    if mode == "superresolution":
        up_f = 4
        visualize_cond_img(selected_path)

        c = Image.open(selected_path)
        c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)
        c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]], antialias=True)
        c_up = rearrange(c_up, "1 c h w -> 1 h w c")
        c = rearrange(c, "1 c h w -> 1 h w c")
        c = 2.0 * c - 1.0

        device = choose_torch_device()
        c = c.to(device)
        example["LR_image"] = c
        example["image"] = c_up

    return example


def visualize_cond_img(path):
    display(ipyimg(filename=path))  # noqa: F821


def run(model, selected_path, task, custom_steps, resize_enabled=False, classifier_ckpt=None, global_step=None):
    example = get_cond(task, selected_path)

    save_intermediate_vid = False
    n_runs = 1
    masked = False
    guider = None
    ckwargs = None
    mode = "ddim"
    ddim_use_x0_pred = False
    temperature = 1.0
    eta = 1.0
    make_progrow = True
    custom_shape = None

    height, width = example["image"].shape[1:3]
    split_input = height >= 128 and width >= 128

    if split_input:
        ks = 128
        stride = 64
        vqf = 4
        model.split_input_params = {
            "ks": (ks, ks),
            "stride": (stride, stride),
            "vqf": vqf,
            "patch_distributed_vq": True,
            "tie_braker": False,
            "clip_max_weight": 0.5,
            "clip_min_weight": 0.01,
            "clip_max_tie_weight": 0.5,
            "clip_min_tie_weight": 0.01,
        }
    else:
        if hasattr(model, "split_input_params"):
            delattr(model, "split_input_params")

    invert_mask = False

    x_T = None
    for n in range(n_runs):
        if custom_shape is not None:
            x_T = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
            x_T = repeat(x_T, "1 c h w -> b c h w", b=custom_shape[0])

        logs = make_convolutional_sample(
            example,
            model,
            mode=mode,
            custom_steps=custom_steps,
            eta=eta,
            swap_mode=False,
            masked=masked,
            invert_mask=invert_mask,
            quantize_x0=False,
            custom_schedule=None,
            decode_interval=10,
            resize_enabled=resize_enabled,
            custom_shape=custom_shape,
            temperature=temperature,
            noise_dropout=0.0,
            corrector=guider,
            corrector_kwargs=ckwargs,
            x_T=x_T,
            save_intermediate_vid=save_intermediate_vid,
            make_progrow=make_progrow,
            ddim_use_x0_pred=ddim_use_x0_pred,
        )
    return logs
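

# Example (sketch of the intended notebook flow; assumes a Colab-like
# environment where the checkpoint URLs above are reachable and CUDA is
# available):
#
#     model = get_model("superresolution")  # downloads config + weights
#     selected_path = select_cond_path("superresolution")
#     logs = run(model["model"], selected_path, "superresolution", custom_steps=100)
#     sample = logs["sample"]  # upscaled image tensor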


@torch.no_grad()
def convsample_ddim(
    model,
    cond,
    steps,
    shape,
    eta=1.0,
    callback=None,
    normals_sequence=None,
    mask=None,
    x0=None,
    quantize_x0=False,
    img_callback=None,
    temperature=1.0,
    noise_dropout=0.0,
    score_corrector=None,
    corrector_kwargs=None,
    x_T=None,
    log_every_t=None,
):
    ddim = DDIMSampler(model)
    bs = shape[0]  # the batch size is carried in shape[0]
    shape = shape[1:]  # cut batch dim
    print(f"Sampling with eta = {eta}; steps: {steps}")
    samples, intermediates = ddim.sample(
        steps,
        batch_size=bs,
        shape=shape,
        conditioning=cond,
        callback=callback,
        normals_sequence=normals_sequence,
        quantize_x0=quantize_x0,
        eta=eta,
        mask=mask,
        x0=x0,
        temperature=temperature,
        verbose=False,
        score_corrector=score_corrector,
        corrector_kwargs=corrector_kwargs,
        x_T=x_T,
    )

    return samples, intermediates


@torch.no_grad()
def make_convolutional_sample(
    batch,
    model,
    mode="vanilla",
    custom_steps=None,
    eta=1.0,
    swap_mode=False,
    masked=False,
    invert_mask=True,
    quantize_x0=False,
    custom_schedule=None,
    decode_interval=1000,
    resize_enabled=False,
    custom_shape=None,
    temperature=1.0,
    noise_dropout=0.0,
    corrector=None,
    corrector_kwargs=None,
    x_T=None,
    save_intermediate_vid=False,
    make_progrow=True,
    ddim_use_x0_pred=False,
):
    log = dict()

    z, c, x, xrec, xc = model.get_input(
        batch,
        model.first_stage_key,
        return_first_stage_outputs=True,
        force_c_encode=not (hasattr(model, "split_input_params") and model.cond_stage_key == "coordinates_bbox"),
        return_original_cond=True,
    )

    log_every_t = 1 if save_intermediate_vid else None

    if custom_shape is not None:
        z = torch.randn(custom_shape)
        print(f"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}")

    z0 = None

    log["input"] = x
    log["reconstruction"] = xrec

    if ismap(xc):
        log["original_conditioning"] = model.to_rgb(xc)
        if hasattr(model, "cond_stage_key"):
            log[model.cond_stage_key] = model.to_rgb(xc)

    else:
        log["original_conditioning"] = xc if xc is not None else torch.zeros_like(x)
        if model.cond_stage_model:
            log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x)
            if model.cond_stage_key == "class_label":
                log[model.cond_stage_key] = xc[model.cond_stage_key]

    with model.ema_scope("Plotting"):
        t0 = time.time()
        img_cb = None

        sample, intermediates = convsample_ddim(
            model,
            c,
            steps=custom_steps,
            shape=z.shape,
            eta=eta,
            quantize_x0=quantize_x0,
            img_callback=img_cb,
            mask=None,
            x0=z0,
            temperature=temperature,
            noise_dropout=noise_dropout,
            score_corrector=corrector,
            corrector_kwargs=corrector_kwargs,
            x_T=x_T,
            log_every_t=log_every_t,
        )
        t1 = time.time()

        if ddim_use_x0_pred:
            sample = intermediates["pred_x0"][-1]

    x_sample = model.decode_first_stage(sample)

    try:
        x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
        log["sample_noquant"] = x_sample_noquant
        log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
    except Exception:
        pass

    log["sample"] = x_sample
    log["time"] = t1 - t0

    return log
@ -1,14 +0,0 @@
#!/usr/bin/env python
# Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein)

import warnings

from invokeai.app.cli_app import invoke_cli

warnings.warn(
    "dream.py is being deprecated, please run invoke.py for the new UI/API or legacy_api.py for the old API",
    DeprecationWarning,
)

invoke_cli()
@ -1,4 +0,0 @@
from invokeai.backend.install.migrate_to_3 import main

if __name__ == "__main__":
    main()
@ -1,41 +0,0 @@
#!/bin/bash
wget -O models/first_stage_models/kl-f4/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f4.zip
wget -O models/first_stage_models/kl-f8/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f8.zip
wget -O models/first_stage_models/kl-f16/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f16.zip
wget -O models/first_stage_models/kl-f32/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f32.zip
wget -O models/first_stage_models/vq-f4/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4.zip
wget -O models/first_stage_models/vq-f4-noattn/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4-noattn.zip
wget -O models/first_stage_models/vq-f8/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8.zip
wget -O models/first_stage_models/vq-f8-n256/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8-n256.zip
wget -O models/first_stage_models/vq-f16/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f16.zip



cd models/first_stage_models/kl-f4
unzip -o model.zip

cd ../kl-f8
unzip -o model.zip

cd ../kl-f16
unzip -o model.zip

cd ../kl-f32
unzip -o model.zip

cd ../vq-f4
unzip -o model.zip

cd ../vq-f4-noattn
unzip -o model.zip

cd ../vq-f8
unzip -o model.zip

cd ../vq-f8-n256
unzip -o model.zip

cd ../vq-f16
unzip -o model.zip

cd ../..
@ -1,49 +0,0 @@
#!/bin/bash
wget -O models/ldm/celeba256/celeba-256.zip https://ommer-lab.com/files/latent-diffusion/celeba.zip
wget -O models/ldm/ffhq256/ffhq-256.zip https://ommer-lab.com/files/latent-diffusion/ffhq.zip
wget -O models/ldm/lsun_churches256/lsun_churches-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_churches.zip
wget -O models/ldm/lsun_beds256/lsun_beds-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_bedrooms.zip
wget -O models/ldm/text2img256/model.zip https://ommer-lab.com/files/latent-diffusion/text2img.zip
wget -O models/ldm/cin256/model.zip https://ommer-lab.com/files/latent-diffusion/cin.zip
wget -O models/ldm/semantic_synthesis512/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis.zip
wget -O models/ldm/semantic_synthesis256/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis256.zip
wget -O models/ldm/bsr_sr/model.zip https://ommer-lab.com/files/latent-diffusion/sr_bsr.zip
wget -O models/ldm/layout2img-openimages256/model.zip https://ommer-lab.com/files/latent-diffusion/layout2img_model.zip
wget -O models/ldm/inpainting_big/model.zip https://ommer-lab.com/files/latent-diffusion/inpainting_big.zip



cd models/ldm/celeba256
unzip -o celeba-256.zip

cd ../ffhq256
unzip -o ffhq-256.zip

cd ../lsun_churches256
unzip -o lsun_churches-256.zip

cd ../lsun_beds256
unzip -o lsun_beds-256.zip

cd ../text2img256
unzip -o model.zip

cd ../cin256
unzip -o model.zip

cd ../semantic_synthesis512
unzip -o model.zip

cd ../semantic_synthesis256
unzip -o model.zip

cd ../bsr_sr
unzip -o model.zip

cd ../layout2img-openimages256
unzip -o model.zip

cd ../inpainting_big
unzip -o model.zip

cd ../..
@ -1,285 +0,0 @@
"""make variations of input image"""

import argparse
import os
import PIL
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange, repeat
from torchvision.utils import make_grid
from torch import autocast
from contextlib import nullcontext
from pytorch_lightning import seed_everything

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.invoke.devices import choose_torch_device


def chunk(it, size):
    it = iter(it)
    return iter(lambda: tuple(islice(it, size)), ())
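
# The two-argument form of iter() calls the lambda until it returns the
# sentinel (), so chunk yields consecutive tuples of up to `size` items until
# `it` is exhausted.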


def load_model_from_config(config, ckpt, verbose=False):
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)

    model.to(choose_torch_device())
    model.eval()
    return model


def load_img(path):
    image = Image.open(path).convert("RGB")
    w, h = image.size
    print(f"loaded input image of size ({w}, {h}) from {path}")
    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    return 2.0 * image - 1.0  # rescale from [0, 1] to the model's [-1, 1] range


def main():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--prompt",
        type=str,
        nargs="?",
        default="a painting of a virus monster playing guitar",
        help="the prompt to render",
    )

    parser.add_argument("--init-img", type=str, nargs="?", help="path to the input image")

    parser.add_argument(
        "--outdir", type=str, nargs="?", help="dir to write results to", default="outputs/img2img-samples"
    )

    parser.add_argument(
        "--skip_grid",
        action="store_true",
        help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
    )

    parser.add_argument(
        "--skip_save",
        action="store_true",
        help="do not save individual samples. For speed measurements.",
    )

    parser.add_argument(
        "--ddim_steps",
        type=int,
        default=50,
        help="number of ddim sampling steps",
    )

    parser.add_argument(
        "--plms",
        action="store_true",
        help="use plms sampling",
    )
    parser.add_argument(
        "--fixed_code",
        action="store_true",
        help="if enabled, uses the same starting code across all samples",
    )

    parser.add_argument(
        "--ddim_eta",
        type=float,
        default=0.0,
        help="ddim eta (eta=0.0 corresponds to deterministic sampling)",
    )
    parser.add_argument(
        "--n_iter",
        type=int,
        default=1,
        help="number of sampling iterations",
    )
    parser.add_argument(
        "--C",
        type=int,
        default=4,
        help="latent channels",
    )
    parser.add_argument(
        "--f",
        type=int,
        default=8,
        help="downsampling factor, most often 8 or 16",
    )
    parser.add_argument(
        "--n_samples",
        type=int,
        default=2,
        help="how many samples to produce for each given prompt, i.e. the batch size",
    )
    parser.add_argument(
        "--n_rows",
        type=int,
        default=0,
        help="rows in the grid (default: n_samples)",
    )
    parser.add_argument(
        "--scale",
        type=float,
        default=5.0,
        help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
    )

    parser.add_argument(
        "--strength",
        type=float,
        default=0.75,
        help="strength for noising/unnoising. 1.0 corresponds to full destruction of information in init image",
    )
    parser.add_argument(
        "--from-file",
        type=str,
        help="if specified, load prompts from this file",
    )
    parser.add_argument(
        "--config",
        type=str,
        default="configs/stable-diffusion/v1-inference.yaml",
        help="path to config which constructs model",
    )
    parser.add_argument(
        "--ckpt",
        type=str,
        default="models/ldm/stable-diffusion-v1/model.ckpt",
        help="path to checkpoint of model",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="the seed (for reproducible sampling)",
    )
    parser.add_argument(
        "--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast"
    )

    opt = parser.parse_args()
    seed_everything(opt.seed)

    config = OmegaConf.load(f"{opt.config}")
    model = load_model_from_config(config, f"{opt.ckpt}")

    device = torch.device(choose_torch_device())
    model = model.to(device)

    if opt.plms:
        raise NotImplementedError("PLMS sampler not (yet) supported")
    else:
        sampler = DDIMSampler(model)

    os.makedirs(opt.outdir, exist_ok=True)
    outpath = opt.outdir

    batch_size = opt.n_samples
    n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
    if not opt.from_file:
        prompt = opt.prompt
        assert prompt is not None
        data = [batch_size * [prompt]]

    else:
        print(f"reading prompts from {opt.from_file}")
        with open(opt.from_file, "r") as f:
            data = f.read().splitlines()
            data = list(chunk(data, batch_size))

    sample_path = os.path.join(outpath, "samples")
    os.makedirs(sample_path, exist_ok=True)
    base_count = len(os.listdir(sample_path))
    grid_count = len(os.listdir(outpath)) - 1

    assert os.path.isfile(opt.init_img)
    init_image = load_img(opt.init_img).to(device)
    init_image = repeat(init_image, "1 ... -> b ...", b=batch_size)
    init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image))  # move to latent space

    sampler.make_schedule(ddim_num_steps=opt.ddim_steps, ddim_eta=opt.ddim_eta, verbose=False)

    assert 0.0 <= opt.strength <= 1.0, "can only work with strength in [0.0, 1.0]"
    t_enc = int(opt.strength * opt.ddim_steps)
    print(f"target t_enc is {t_enc} steps")
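    # e.g. with the defaults --strength 0.75 --ddim_steps 50:
    # t_enc = int(0.75 * 50) = 37, so the init latent is noised for 37 of the
    # 50 DDIM steps and then denoised for those same 37 steps.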

    precision_scope = autocast if opt.precision == "autocast" else nullcontext
    if device.type in ["mps", "cpu"]:
        precision_scope = nullcontext  # have to use f32 on mps
    with torch.no_grad():
        with precision_scope(device.type):
            with model.ema_scope():
                all_samples = list()
                for n in trange(opt.n_iter, desc="Sampling"):
                    for prompts in tqdm(data, desc="data"):
                        uc = None
                        if opt.scale != 1.0:
                            uc = model.get_learned_conditioning(batch_size * [""])
                        if isinstance(prompts, tuple):
                            prompts = list(prompts)
                        c = model.get_learned_conditioning(prompts)

                        # encode (scaled latent)
                        z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc] * batch_size).to(device))
                        # decode it
                        samples = sampler.decode(
                            z_enc,
                            c,
                            t_enc,
                            unconditional_guidance_scale=opt.scale,
                            unconditional_conditioning=uc,
                        )

                        x_samples = model.decode_first_stage(samples)
                        x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)

                        if not opt.skip_save:
                            for x_sample in x_samples:
                                x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c")
                                Image.fromarray(x_sample.astype(np.uint8)).save(
                                    os.path.join(sample_path, f"{base_count:05}.png")
                                )
                                base_count += 1
                        all_samples.append(x_samples)

                if not opt.skip_grid:
                    # additionally, save as grid
                    grid = torch.stack(all_samples, 0)
                    grid = rearrange(grid, "n b c h w -> (n b) c h w")
                    grid = make_grid(grid, nrow=n_rows)

                    # to image
                    grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy()
                    Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f"grid-{grid_count:04}.png"))
                    grid_count += 1

    print(f"Your samples are ready and waiting for you here:\n{outpath}\n\nEnjoy.")


if __name__ == "__main__":
    main()
@ -1,94 +0,0 @@
import argparse
import glob
import os
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm
import numpy as np
import torch
from main import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.invoke.devices import choose_torch_device


def make_batch(image, mask, device):
    image = np.array(Image.open(image).convert("RGB"))
    image = image.astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)

    mask = np.array(Image.open(mask).convert("L"))
    mask = mask.astype(np.float32) / 255.0
    mask = mask[None, None]
    mask[mask < 0.5] = 0
    mask[mask >= 0.5] = 1
    mask = torch.from_numpy(mask)

    masked_image = (1 - mask) * image

    batch = {"image": image, "mask": mask, "masked_image": masked_image}
    for k in batch:
        batch[k] = batch[k].to(device=device)
        batch[k] = batch[k] * 2.0 - 1.0
    return batch
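
# Note: make_batch binarizes the mask at 0.5 and rescales image, mask, and
# masked_image from [0, 1] to [-1, 1], the input range the model expects.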


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--indir",
        type=str,
        nargs="?",
        help="dir containing image-mask pairs (`example.png` and `example_mask.png`)",
    )
    parser.add_argument(
        "--outdir",
        type=str,
        nargs="?",
        help="dir to write results to",
    )
    parser.add_argument(
        "--steps",
        type=int,
        default=50,
        help="number of ddim sampling steps",
    )
    opt = parser.parse_args()

    masks = sorted(glob.glob(os.path.join(opt.indir, "*_mask.png")))
    images = [x.replace("_mask.png", ".png") for x in masks]
    print(f"Found {len(masks)} inputs.")

    config = OmegaConf.load("models/ldm/inpainting_big/config.yaml")
    model = instantiate_from_config(config.model)
    model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"], strict=False)

    device = choose_torch_device()
    model = model.to(device)
    sampler = DDIMSampler(model)

    os.makedirs(opt.outdir, exist_ok=True)
    with torch.no_grad():
        with model.ema_scope():
            for image, mask in tqdm(zip(images, masks)):
                outpath = os.path.join(opt.outdir, os.path.split(image)[1])
                batch = make_batch(image, mask, device=device)

                # encode masked image and concat downsampled mask
                c = model.cond_stage_model.encode(batch["masked_image"])
                cc = torch.nn.functional.interpolate(batch["mask"], size=c.shape[-2:])
                c = torch.cat((c, cc), dim=1)
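                # c is now the first-stage encoding of the masked image plus
                # one extra channel holding the downsampled mask, so the
                # latent to be sampled has c.shape[1] - 1 channels (computed
                # below).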

                shape = (c.shape[1] - 1,) + c.shape[2:]
                samples_ddim, _ = sampler.sample(
                    S=opt.steps, conditioning=c, batch_size=c.shape[0], shape=shape, verbose=False
                )
                x_samples_ddim = model.decode_first_stage(samples_ddim)

                image = torch.clamp((batch["image"] + 1.0) / 2.0, min=0.0, max=1.0)
                mask = torch.clamp((batch["mask"] + 1.0) / 2.0, min=0.0, max=1.0)
                predicted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

                inpainted = (1 - mask) * image + mask * predicted_image
                inpainted = inpainted.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
                Image.fromarray(inpainted.astype(np.uint8)).save(outpath)
@ -1,397 +0,0 @@
import argparse
import glob
import os
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange
from torchvision.utils import make_grid
import scann
import time
from multiprocessing import cpu_count

from ldm.util import instantiate_from_config, parallel_data_prefetch
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.modules.encoders.modules import FrozenClipImageEmbedder, FrozenCLIPTextEmbedder

DATABASES = [
    "openimages",
    "artbench-art_nouveau",
    "artbench-baroque",
    "artbench-expressionism",
    "artbench-impressionism",
    "artbench-post_impressionism",
    "artbench-realism",
    "artbench-romanticism",
    "artbench-renaissance",
    "artbench-surrealism",
    "artbench-ukiyo_e",
]


def chunk(it, size):
    it = iter(it)
    return iter(lambda: tuple(islice(it, size)), ())


def load_model_from_config(config, ckpt, verbose=False):
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)

    model.cuda()
    model.eval()
    return model


class Searcher(object):
    def __init__(self, database, retriever_version="ViT-L/14"):
        assert database in DATABASES
        # self.database = self.load_database(database)
        self.database_name = database
        self.searcher_savedir = f"data/rdm/searchers/{self.database_name}"
        self.database_path = f"data/rdm/retrieval_databases/{self.database_name}"
        self.retriever = self.load_retriever(version=retriever_version)
        self.database = {"embedding": [], "img_id": [], "patch_coords": []}
        self.load_database()
        self.load_searcher()

    def train_searcher(self, k, metric="dot_product", searcher_savedir=None):
        print("Start training searcher")
        searcher = scann.scann_ops_pybind.builder(
            self.database["embedding"] / np.linalg.norm(self.database["embedding"], axis=1)[:, np.newaxis], k, metric
        )
        self.searcher = searcher.score_brute_force().build()
        print("Finished training searcher")

        if searcher_savedir is not None:
            print(f'Saving trained searcher under "{searcher_savedir}"')
            os.makedirs(searcher_savedir, exist_ok=True)
            self.searcher.serialize(searcher_savedir)

    def load_single_file(self, saved_embeddings):
        compressed = np.load(saved_embeddings)
        self.database = {key: compressed[key] for key in compressed.files}
        print("Finished loading of clip embeddings.")

    def load_multi_files(self, data_archive):
        out_data = {key: [] for key in self.database}
        for d in tqdm(data_archive, desc=f"Loading datapool from {len(data_archive)} individual files."):
            for key in d.files:
                out_data[key].append(d[key])

        return out_data

    def load_database(self):
        print(f'Load saved patch embedding from "{self.database_path}"')
        file_content = glob.glob(os.path.join(self.database_path, "*.npz"))

        if len(file_content) == 1:
            self.load_single_file(file_content[0])
        elif len(file_content) > 1:
            data = [np.load(f) for f in file_content]
            prefetched_data = parallel_data_prefetch(
                self.load_multi_files, data, n_proc=min(len(data), cpu_count()), target_data_type="dict"
            )

            self.database = {
                key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in self.database
            }
        else:
            raise ValueError(f'No npz files found in "{self.database_path}"; does this directory exist?')

        print(f'Finished loading of retrieval database of length {self.database["embedding"].shape[0]}.')

    def load_retriever(
        self,
        version="ViT-L/14",
    ):
        model = FrozenClipImageEmbedder(model=version)
        if torch.cuda.is_available():
            model.cuda()
        model.eval()
        return model

    def load_searcher(self):
        print(f"load searcher for database {self.database_name} from {self.searcher_savedir}")
        self.searcher = scann.scann_ops_pybind.load_searcher(self.searcher_savedir)
        print("Finished loading searcher.")

    def search(self, x, k):
        if self.searcher is None and self.database["embedding"].shape[0] < 2e4:
            self.train_searcher(k)  # quickly fit searcher on the fly for small databases
        assert self.searcher is not None, "Cannot search with uninitialized searcher"
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
        if len(x.shape) == 3:
            x = x[:, 0]
        query_embeddings = x / np.linalg.norm(x, axis=1)[:, np.newaxis]
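        # Database and query embeddings are both L2-normalized, so the
        # searcher's "dot_product" metric effectively ranks neighbors by
        # cosine similarity.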

        start = time.time()
        nns, distances = self.searcher.search_batched(query_embeddings, final_num_neighbors=k)
        end = time.time()

        out_embeddings = self.database["embedding"][nns]
        out_img_ids = self.database["img_id"][nns]
        out_pc = self.database["patch_coords"][nns]

        out = {
            "nn_embeddings": out_embeddings / np.linalg.norm(out_embeddings, axis=-1)[..., np.newaxis],
            "img_ids": out_img_ids,
            "patch_coords": out_pc,
            "queries": x,
            "exec_time": end - start,
            "nns": nns,
            "q_embeddings": query_embeddings,
        }

        return out

    def __call__(self, x, n):
        return self.search(x, n)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # TODO: add n_neighbors and modes (text-only, text-image-retrieval, image-image retrieval etc)
    # TODO: add 'image variation' mode when knn=0 but a single image is given instead of a text prompt?
    parser.add_argument(
        "--prompt",
        type=str,
        nargs="?",
        default="a painting of a virus monster playing guitar",
        help="the prompt to render",
    )

    parser.add_argument(
        "--outdir", type=str, nargs="?", help="dir to write results to", default="outputs/txt2img-samples"
    )

    parser.add_argument(
        "--skip_grid",
        action="store_true",
        help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
    )

    parser.add_argument(
        "--ddim_steps",
        type=int,
        default=50,
        help="number of ddim sampling steps",
    )

    parser.add_argument(
        "--n_repeat",
        type=int,
        default=1,
        help="number of repeats in CLIP latent space",
    )

    parser.add_argument(
        "--plms",
        action="store_true",
        help="use plms sampling",
    )

    parser.add_argument(
        "--ddim_eta",
        type=float,
        default=0.0,
        help="ddim eta (eta=0.0 corresponds to deterministic sampling)",
    )
    parser.add_argument(
        "--n_iter",
        type=int,
        default=1,
        help="number of sampling iterations",
    )

    parser.add_argument(
        "--H",
        type=int,
        default=768,
        help="image height, in pixel space",
    )

    parser.add_argument(
        "--W",
        type=int,
        default=768,
        help="image width, in pixel space",
    )

    parser.add_argument(
        "--n_samples",
        type=int,
        default=3,
        help="how many samples to produce for each given prompt, i.e. the batch size",
    )

    parser.add_argument(
        "--n_rows",
        type=int,
        default=0,
        help="rows in the grid (default: n_samples)",
    )

    parser.add_argument(
        "--scale",
        type=float,
        default=5.0,
        help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
    )

    parser.add_argument(
        "--from-file",
        type=str,
        help="if specified, load prompts from this file",
    )

    parser.add_argument(
        "--config",
        type=str,
        default="configs/retrieval-augmented-diffusion/768x768.yaml",
        help="path to config which constructs model",
    )

    parser.add_argument(
        "--ckpt",
        type=str,
        default="models/rdm/rdm768x768/model.ckpt",
        help="path to checkpoint of model",
    )

    parser.add_argument(
        "--clip_type",
        type=str,
        default="ViT-L/14",
        help="which CLIP model to use for retrieval and NN encoding",
    )
    parser.add_argument(
        "--database",
        type=str,
        default="artbench-surrealism",
        choices=DATABASES,
        help="The database used for the search, only applied when --use_neighbors=True",
    )
    parser.add_argument(
        "--use_neighbors",
        default=False,
        action="store_true",
        help="Include neighbors in addition to text prompt for conditioning",
    )
    parser.add_argument(
        "--knn",
        default=10,
        type=int,
        help="The number of included neighbors, only applied when --use_neighbors=True",
    )

    opt = parser.parse_args()

    config = OmegaConf.load(f"{opt.config}")
    model = load_model_from_config(config, f"{opt.ckpt}")

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)

    clip_text_encoder = FrozenCLIPTextEmbedder(opt.clip_type).to(device)

    if opt.plms:
        sampler = PLMSSampler(model)
    else:
        sampler = DDIMSampler(model)

    os.makedirs(opt.outdir, exist_ok=True)
    outpath = opt.outdir

    batch_size = opt.n_samples
    n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
    if not opt.from_file:
        prompt = opt.prompt
        assert prompt is not None
        data = [batch_size * [prompt]]

    else:
        print(f"reading prompts from {opt.from_file}")
        with open(opt.from_file, "r") as f:
            data = f.read().splitlines()
            data = list(chunk(data, batch_size))

    sample_path = os.path.join(outpath, "samples")
    os.makedirs(sample_path, exist_ok=True)
    base_count = len(os.listdir(sample_path))
    grid_count = len(os.listdir(outpath)) - 1

    print(f"sampling scale for cfg is {opt.scale:.2f}")

    searcher = None
    if opt.use_neighbors:
        searcher = Searcher(opt.database)

    with torch.no_grad():
        with model.ema_scope():
            for n in trange(opt.n_iter, desc="Sampling"):
                all_samples = list()
                for prompts in tqdm(data, desc="data"):
                    print("sampling prompts:", prompts)
                    if isinstance(prompts, tuple):
                        prompts = list(prompts)
                    c = clip_text_encoder.encode(prompts)
                    uc = None
                    if searcher is not None:
                        nn_dict = searcher(c, opt.knn)
                        c = torch.cat([c, torch.from_numpy(nn_dict["nn_embeddings"]).cuda()], dim=1)
                    if opt.scale != 1.0:
                        uc = torch.zeros_like(c)
                    shape = [16, opt.H // 16, opt.W // 16]  # note: currently hardcoded for the f16 model
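                    # e.g. with the defaults H = W = 768 this is a 16 x 48 x 48
                    # latent (768 // 16 = 48), matching the f16 first stage.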
                    samples_ddim, _ = sampler.sample(
                        S=opt.ddim_steps,
                        conditioning=c,
                        batch_size=c.shape[0],
                        shape=shape,
                        verbose=False,
                        unconditional_guidance_scale=opt.scale,
                        unconditional_conditioning=uc,
                        eta=opt.ddim_eta,
                    )

                    x_samples_ddim = model.decode_first_stage(samples_ddim)
                    x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

                    for x_sample in x_samples_ddim:
                        x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c")
                        Image.fromarray(x_sample.astype(np.uint8)).save(
                            os.path.join(sample_path, f"{base_count:05}.png")
                        )
                        base_count += 1
                    all_samples.append(x_samples_ddim)

                if not opt.skip_grid:
                    # additionally, save as grid
                    grid = torch.stack(all_samples, 0)
                    grid = rearrange(grid, "n b c h w -> (n b) c h w")
                    grid = make_grid(grid, nrow=n_rows)

                    # to image
                    grid_np = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy()
                    Image.fromarray(grid_np.astype(np.uint8)).save(os.path.join(outpath, f"grid-{grid_count:04}.png"))
                    grid_count += 1

    print(f"Your samples are ready and waiting for you here:\n{outpath}\nEnjoy.")
File diff suppressed because one or more lines are too long
@ -1,898 +0,0 @@
import argparse
import datetime
import glob
import os
import sys

import numpy as np
import time
import torch
import torchvision
import pytorch_lightning as pl

from packaging import version
from omegaconf import OmegaConf
from torch.utils.data import DataLoader, Dataset
from functools import partial
from PIL import Image

from pytorch_lightning import seed_everything
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities import rank_zero_info

from ldm.data.base import Txt2ImgIterableBaseDataset
from ldm.util import instantiate_from_config


def fix_func(orig):
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():

        def new_func(*args, **kw):
            device = kw.get("device", "mps")
            kw["device"] = "cpu"
            return orig(*args, **kw).to(device)

        return new_func
    return orig


torch.rand = fix_func(torch.rand)
torch.rand_like = fix_func(torch.rand_like)
torch.randn = fix_func(torch.randn)
torch.randn_like = fix_func(torch.randn_like)
torch.randint = fix_func(torch.randint)
torch.randint_like = fix_func(torch.randint_like)
torch.bernoulli = fix_func(torch.bernoulli)
torch.multinomial = fix_func(torch.multinomial)
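# Note: the wrappers above work around incomplete RNG support on the MPS
# backend by drawing random tensors on the CPU and moving the result to the
# requested device.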


def load_model_from_config(config, ckpt, verbose=False):
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    sd = pl_sd["state_dict"]
    config.model.params.ckpt_path = ckpt
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)

    if torch.cuda.is_available():
        model.cuda()
    return model


def get_parser(**parser_kwargs):
    def str2bool(v):
        if isinstance(v, bool):
            return v
        if v.lower() in ("yes", "true", "t", "y", "1"):
            return True
        elif v.lower() in ("no", "false", "f", "n", "0"):
            return False
        else:
            raise argparse.ArgumentTypeError("Boolean value expected.")

    parser = argparse.ArgumentParser(**parser_kwargs)
    parser.add_argument(
        "-n",
        "--name",
        type=str,
        const=True,
        default="",
        nargs="?",
        help="postfix for logdir",
    )
    parser.add_argument(
        "-r",
        "--resume",
        type=str,
        const=True,
        default="",
        nargs="?",
        help="resume from logdir or checkpoint in logdir",
    )
    parser.add_argument(
        "-b",
        "--base",
        nargs="*",
        metavar="base_config.yaml",
        help="paths to base configs. Loaded from left-to-right. "
        "Parameters can be overwritten or added with command-line options of the form `--key value`.",
        default=list(),
    )
    parser.add_argument(
        "-t",
        "--train",
        type=str2bool,
        const=True,
        default=False,
        nargs="?",
        help="train",
    )
    parser.add_argument(
        "--no-test",
        type=str2bool,
        const=True,
        default=False,
        nargs="?",
        help="disable test",
    )
    parser.add_argument("-p", "--project", help="name of new or path to existing project")
    parser.add_argument(
        "-d",
        "--debug",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="enable post-mortem debugging",
    )
    parser.add_argument(
        "-s",
        "--seed",
        type=int,
        default=23,
        help="seed for seed_everything",
    )
    parser.add_argument(
        "-f",
        "--postfix",
        type=str,
        default="",
        help="post-postfix for default name",
    )
    parser.add_argument(
        "-l",
        "--logdir",
        type=str,
        default="logs",
        help="directory for logging",
    )
    parser.add_argument(
        "--scale_lr",
        type=str2bool,
        nargs="?",
        const=True,
        default=True,
        help="scale base-lr by ngpu * batch_size * n_accumulate",
    )

    parser.add_argument(
        "--datadir_in_name",
        type=str2bool,
        nargs="?",
        const=True,
        default=True,
        help="Prepend the final directory in the data_root to the output directory name",
    )

    parser.add_argument(
        "--actual_resume",
        type=str,
        default="",
        help="Path to model to actually resume from",
    )
    parser.add_argument(
        "--data_root",
        type=str,
        required=True,
        help="Path to directory with training images",
    )

    parser.add_argument(
        "--embedding_manager_ckpt",
        type=str,
        default="",
        help="Initialize embedding manager from a checkpoint",
    )
    parser.add_argument(
        "--init_word",
        type=str,
        help="Word to use as source for initial token embedding.",
    )

    return parser


def nondefault_trainer_args(opt):
    parser = argparse.ArgumentParser()
    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args([])
    return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k))


class WrappedDataset(Dataset):
    """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset"""

    def __init__(self, dataset):
        self.data = dataset

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


def worker_init_fn(_):
    worker_info = torch.utils.data.get_worker_info()

    dataset = worker_info.dataset
    worker_id = worker_info.id

    if isinstance(dataset, Txt2ImgIterableBaseDataset):
        split_size = dataset.num_records // worker_info.num_workers
        # assign each worker a distinct, contiguous shard of the sample ids
        dataset.sample_ids = dataset.valid_ids[worker_id * split_size : (worker_id + 1) * split_size]
        current_id = np.random.choice(len(np.random.get_state()[1]), 1)
        return np.random.seed(np.random.get_state()[1][current_id] + worker_id)
    else:
        return np.random.seed(np.random.get_state()[1][0] + worker_id)
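# Note: the per-worker np.random.seed(...) calls above keep DataLoader workers
# from sharing numpy RNG state; without them every worker would inherit the
# same seed from the parent process.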


class DataModuleFromConfig(pl.LightningDataModule):
    def __init__(
        self,
        batch_size,
        train=None,
        validation=None,
        test=None,
        predict=None,
        wrap=False,
        num_workers=None,
        shuffle_test_loader=False,
        use_worker_init_fn=False,
        shuffle_val_dataloader=False,
    ):
        super().__init__()
        self.batch_size = batch_size
        self.dataset_configs = dict()
        self.num_workers = num_workers if num_workers is not None else batch_size * 2
        self.use_worker_init_fn = use_worker_init_fn
        if train is not None:
            self.dataset_configs["train"] = train
            self.train_dataloader = self._train_dataloader
        if validation is not None:
            self.dataset_configs["validation"] = validation
            self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader)
        if test is not None:
            self.dataset_configs["test"] = test
            self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader)
        if predict is not None:
            self.dataset_configs["predict"] = predict
            self.predict_dataloader = self._predict_dataloader
        self.wrap = wrap

    def prepare_data(self):
        for data_cfg in self.dataset_configs.values():
            instantiate_from_config(data_cfg)

    def setup(self, stage=None):
        self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs)
        if self.wrap:
            for k in self.datasets:
                self.datasets[k] = WrappedDataset(self.datasets[k])

    def _train_dataloader(self):
        is_iterable_dataset = isinstance(self.datasets["train"], Txt2ImgIterableBaseDataset)
        if is_iterable_dataset or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None
        return DataLoader(
            self.datasets["train"],
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=not is_iterable_dataset,
            worker_init_fn=init_fn,
        )

    def _val_dataloader(self, shuffle=False):
        if isinstance(self.datasets["validation"], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None
        return DataLoader(
            self.datasets["validation"],
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            worker_init_fn=init_fn,
            shuffle=shuffle,
        )

    def _test_dataloader(self, shuffle=False):
        is_iterable_dataset = isinstance(self.datasets["train"], Txt2ImgIterableBaseDataset)
        if is_iterable_dataset or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None

        # do not shuffle dataloader for iterable dataset
        shuffle = shuffle and (not is_iterable_dataset)

        return DataLoader(
            self.datasets["test"],
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            worker_init_fn=init_fn,
            shuffle=shuffle,
        )

    def _predict_dataloader(self, shuffle=False):
        if isinstance(self.datasets["predict"], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None
        return DataLoader(
            self.datasets["predict"],
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            worker_init_fn=init_fn,
        )


class SetupCallback(Callback):
    def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config):
        super().__init__()
        self.resume = resume
        self.now = now
        self.logdir = logdir
        self.ckptdir = ckptdir
        self.cfgdir = cfgdir
        self.config = config
        self.lightning_config = lightning_config

    def on_keyboard_interrupt(self, trainer, pl_module):
        if trainer.global_rank == 0:
            print("Summoning checkpoint.")
            ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
            trainer.save_checkpoint(ckpt_path)

    def on_pretrain_routine_start(self, trainer, pl_module):
        if trainer.global_rank == 0:
            # Create logdirs and save configs
            os.makedirs(self.logdir, exist_ok=True)
            os.makedirs(self.ckptdir, exist_ok=True)
            os.makedirs(self.cfgdir, exist_ok=True)

            if "callbacks" in self.lightning_config:
                if "metrics_over_trainsteps_checkpoint" in self.lightning_config["callbacks"]:
                    os.makedirs(
                        os.path.join(self.ckptdir, "trainstep_checkpoints"),
                        exist_ok=True,
                    )
            print("Project config")
            print(OmegaConf.to_yaml(self.config))
            OmegaConf.save(
                self.config,
                os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)),
            )

            print("Lightning config")
            print(OmegaConf.to_yaml(self.lightning_config))
            OmegaConf.save(
                OmegaConf.create({"lightning": self.lightning_config}),
                os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)),
            )

        else:
            # the ModelCheckpoint callback created the log directory --- move it aside
            if not self.resume and os.path.exists(self.logdir):
                dst, name = os.path.split(self.logdir)
                dst = os.path.join(dst, "child_runs", name)
                os.makedirs(os.path.split(dst)[0], exist_ok=True)
                try:
                    os.rename(self.logdir, dst)
                except FileNotFoundError:
                    pass


class ImageLogger(Callback):
    def __init__(
        self,
        batch_frequency,
        max_images,
        clamp=True,
        increase_log_steps=True,
        rescale=True,
        disabled=False,
        log_on_batch_idx=False,
        log_first_step=False,
        log_images_kwargs=None,
    ):
        super().__init__()
        self.rescale = rescale
        self.batch_freq = batch_frequency
        self.max_images = max_images
        self.logger_log_images = {}
        self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)]
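        # e.g. batch_frequency=1000 gives log_steps = [1, 2, 4, ..., 512], so
        # images are logged densely early in training and every batch_freq
        # steps afterwards.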
        if not increase_log_steps:
            self.log_steps = [self.batch_freq]
        self.clamp = clamp
        self.disabled = disabled
        self.log_on_batch_idx = log_on_batch_idx
        self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
        self.log_first_step = log_first_step

    @rank_zero_only
    def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx):
        root = os.path.join(save_dir, "images", split)
        for k in images:
            grid = torchvision.utils.make_grid(images[k], nrow=4)
            if self.rescale:
                grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w
            grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
            grid = grid.numpy()
            grid = (grid * 255).astype(np.uint8)
            filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx)
            path = os.path.join(root, filename)
            os.makedirs(os.path.split(path)[0], exist_ok=True)
            Image.fromarray(grid).save(path)

    def log_img(self, pl_module, batch, batch_idx, split="train"):
        check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
        if (
            self.check_frequency(check_idx)
            and hasattr(pl_module, "log_images")  # batch_idx % self.batch_freq == 0
            and callable(pl_module.log_images)
            and self.max_images > 0
        ):
            logger = type(pl_module.logger)

            is_train = pl_module.training
            if is_train:
                pl_module.eval()

            with torch.no_grad():
                images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)

            for k in images:
                N = min(images[k].shape[0], self.max_images)
                images[k] = images[k][:N]
                if isinstance(images[k], torch.Tensor):
                    images[k] = images[k].detach().cpu()
                    if self.clamp:
                        images[k] = torch.clamp(images[k], -1.0, 1.0)

            self.log_local(
                pl_module.logger.save_dir,
                split,
                images,
                pl_module.global_step,
                pl_module.current_epoch,
                batch_idx,
            )

            logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)
            logger_log_images(pl_module, images, pl_module.global_step, split)

            if is_train:
                pl_module.train()

    def check_frequency(self, check_idx):
        if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and (
            check_idx > 0 or self.log_first_step
        ):
            try:
                self.log_steps.pop(0)
            except IndexError as e:
                print(e)
            return True
        return False

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None):
        if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
            self.log_img(pl_module, batch, batch_idx, split="train")

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None):
        if not self.disabled and pl_module.global_step > 0:
            self.log_img(pl_module, batch, batch_idx, split="val")
        if hasattr(pl_module, "calibrate_grad_norm"):
            if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0:
                self.log_gradients(trainer, pl_module, batch_idx=batch_idx)


class CUDACallback(Callback):
    # see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py
    def on_train_epoch_start(self, trainer, pl_module):
        # Reset the memory use counter
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats(trainer.root_gpu)
            torch.cuda.synchronize(trainer.root_gpu)
        self.start_time = time.time()

    def on_train_epoch_end(self, trainer, pl_module, outputs=None):
        if torch.cuda.is_available():
            torch.cuda.synchronize(trainer.root_gpu)
        epoch_time = time.time() - self.start_time

        try:
            epoch_time = trainer.training_type_plugin.reduce(epoch_time)
            rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")

            if torch.cuda.is_available():
                max_memory = torch.cuda.max_memory_allocated(trainer.root_gpu) / 2**20
                max_memory = trainer.training_type_plugin.reduce(max_memory)
                rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB")
        except AttributeError:
            pass


class ModeSwapCallback(Callback):
    def __init__(self, swap_step=2000):
        super().__init__()
        self.is_frozen = False
        self.swap_step = swap_step

    def on_train_epoch_start(self, trainer, pl_module):
        if trainer.global_step < self.swap_step and not self.is_frozen:
            self.is_frozen = True
            trainer.optimizers = [pl_module.configure_opt_embedding()]

        if trainer.global_step > self.swap_step and self.is_frozen:
            self.is_frozen = False
            trainer.optimizers = [pl_module.configure_opt_model()]


if __name__ == "__main__":
    # custom parser to specify config files, train, test and debug mode,
    # postfix, resume.
    # `--key value` arguments are interpreted as arguments to the trainer.
    # `nested.key=value` arguments are interpreted as config parameters.
    # configs are merged from left-to-right followed by command line parameters.

    # model:
    #   base_learning_rate: float
    #   target: path to lightning module
    #   params:
    #       key: value
    # data:
    #   target: main.DataModuleFromConfig
    #   params:
    #       batch_size: int
    #       wrap: bool
    #       train:
    #           target: path to train dataset
    #           params:
    #               key: value
    #       validation:
    #           target: path to validation dataset
    #           params:
    #               key: value
    #       test:
    #           target: path to test dataset
    #           params:
    #               key: value
    # lightning: (optional, has sane defaults and can be specified on cmdline)
    #   trainer:
    #       additional arguments to trainer
    #   logger:
    #       logger to instantiate
    #   modelcheckpoint:
    #       modelcheckpoint to instantiate
    #   callbacks:
    #       callback1:
    #           target: importpath
    #           params:
    #               key: value
|
|
||||||
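    # Illustration (not part of the original file): a minimal sketch of the
    # merge order described above, using only OmegaConf calls this script
    # itself relies on. Later configs and dotlist overrides win:
    #
    #   from omegaconf import OmegaConf
    #   base = OmegaConf.create({"model": {"base_learning_rate": 1.0e-4}})
    #   cli = OmegaConf.from_dotlist(["model.base_learning_rate=5.0e-5"])
    #   merged = OmegaConf.merge(base, cli)
    #   assert merged.model.base_learning_rate == 5.0e-5
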
    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")

    # add cwd for convenience and to make classes in this file available when
    # running as `python main.py`
    # (in particular `main.DataModuleFromConfig`)
    sys.path.append(os.getcwd())

    parser = get_parser()
    parser = Trainer.add_argparse_args(parser)

    opt, unknown = parser.parse_known_args()
    if opt.name and opt.resume:
        raise ValueError(
            "-n/--name and -r/--resume cannot both be specified. "
            "If you want to resume training in a new log folder, "
            "use -n/--name in combination with --resume_from_checkpoint"
        )
    if opt.resume:
        if not os.path.exists(opt.resume):
            raise ValueError("Cannot find {}".format(opt.resume))
        if os.path.isfile(opt.resume):
            paths = opt.resume.split("/")
            # idx = len(paths)-paths[::-1].index("logs")+1
            # logdir = "/".join(paths[:idx])
            logdir = "/".join(paths[:-2])
            ckpt = opt.resume
        else:
            assert os.path.isdir(opt.resume), opt.resume
            logdir = opt.resume.rstrip("/")
            ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")

        opt.resume_from_checkpoint = ckpt
        base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
        opt.base = base_configs + opt.base
        _tmp = logdir.split("/")
        nowname = _tmp[-1]
    else:
        if opt.name:
            name = "_" + opt.name
        elif opt.base:
            cfg_fname = os.path.split(opt.base[0])[-1]
            cfg_name = os.path.splitext(cfg_fname)[0]
            name = "_" + cfg_name
        else:
            name = ""

        if opt.datadir_in_name:
            now = os.path.basename(os.path.normpath(opt.data_root)) + now

        nowname = now + name + opt.postfix
        logdir = os.path.join(opt.logdir, nowname)

    ckptdir = os.path.join(logdir, "checkpoints")
    cfgdir = os.path.join(logdir, "configs")
    seed_everything(opt.seed)

    try:
        # init and save configs
        configs = [OmegaConf.load(cfg) for cfg in opt.base]
        cli = OmegaConf.from_dotlist(unknown)
        config = OmegaConf.merge(*configs, cli)
        lightning_config = config.pop("lightning", OmegaConf.create())
        # merge trainer cli with config
        trainer_config = lightning_config.get("trainer", OmegaConf.create())
        # default to the "auto" accelerator; dropped again below for CPU runs
        trainer_config["accelerator"] = "auto"
        for k in nondefault_trainer_args(opt):
            trainer_config[k] = getattr(opt, k)
        if "gpus" not in trainer_config:
            del trainer_config["accelerator"]
            cpu = True
        else:
            gpuinfo = trainer_config["gpus"]
            print(f"Running on GPUs {gpuinfo}")
            cpu = False
        trainer_opt = argparse.Namespace(**trainer_config)
        lightning_config.trainer = trainer_config

        # model

        # config.model.params.personalization_config.params.init_word = opt.init_word
        config.model.params.personalization_config.params.embedding_manager_ckpt = opt.embedding_manager_ckpt

        if opt.init_word:
            config.model.params.personalization_config.params.initializer_words = [opt.init_word]

        if opt.actual_resume:
            model = load_model_from_config(config, opt.actual_resume)
        else:
            model = instantiate_from_config(config.model)

        # trainer and callbacks
        trainer_kwargs = dict()

        # default logger configs
        def_logger = "csv"
        def_logger_target = "CSVLogger"
        default_logger_cfgs = {
            "wandb": {
                "target": "pytorch_lightning.loggers.WandbLogger",
                "params": {
                    "name": nowname,
                    "save_dir": logdir,
                    "offline": opt.debug,
                    "id": nowname,
                },
            },
            def_logger: {
                "target": "pytorch_lightning.loggers." + def_logger_target,
                "params": {
                    "name": def_logger,
                    "save_dir": logdir,
                },
            },
        }
        default_logger_cfg = default_logger_cfgs[def_logger]
        if "logger" in lightning_config:
            logger_cfg = lightning_config.logger
        else:
            logger_cfg = OmegaConf.create()
        logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
        trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)

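        # Illustration (not part of the original file): each of the
        # `{"target": ..., "params": ...}` dicts above is consumed by
        # `instantiate_from_config`, which (roughly, as in ldm.util) imports
        # the dotted path and calls it with the params:
        #
        #   import importlib
        #   def _instantiate(cfg):
        #       module, cls = cfg["target"].rsplit(".", 1)
        #       return getattr(importlib.import_module(module), cls)(**cfg.get("params", {}))
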
        # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
        # specify which metric is used to determine best models
        default_modelckpt_cfg = {
            "target": "pytorch_lightning.callbacks.ModelCheckpoint",
            "params": {
                "dirpath": ckptdir,
                "filename": "{epoch:06}",
                "verbose": True,
                "save_last": True,
            },
        }
        if hasattr(model, "monitor"):
            print(f"Monitoring {model.monitor} as checkpoint metric.")
            default_modelckpt_cfg["params"]["monitor"] = model.monitor
            default_modelckpt_cfg["params"]["save_top_k"] = 1

        if "modelcheckpoint" in lightning_config:
            modelckpt_cfg = lightning_config.modelcheckpoint
        else:
            modelckpt_cfg = OmegaConf.create()
        modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
        print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}")
        if version.parse(pl.__version__) < version.parse("1.4.0"):
            trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)

        # add callback which sets up log directory
        default_callbacks_cfg = {
            "setup_callback": {
                "target": "main.SetupCallback",
                "params": {
                    "resume": opt.resume,
                    "now": now,
                    "logdir": logdir,
                    "ckptdir": ckptdir,
                    "cfgdir": cfgdir,
                    "config": config,
                    "lightning_config": lightning_config,
                },
            },
            "image_logger": {
                "target": "main.ImageLogger",
                "params": {
                    "batch_frequency": 750,
                    "max_images": 4,
                    "clamp": True,
                },
            },
            "learning_rate_logger": {
                "target": "main.LearningRateMonitor",
                "params": {
                    "logging_interval": "step",
                    # "log_momentum": True
                },
            },
            "cuda_callback": {"target": "main.CUDACallback"},
        }
        if version.parse(pl.__version__) >= version.parse("1.4.0"):
            default_callbacks_cfg.update({"checkpoint_callback": modelckpt_cfg})

        if "callbacks" in lightning_config:
            callbacks_cfg = lightning_config.callbacks
        else:
            callbacks_cfg = OmegaConf.create()

        if "metrics_over_trainsteps_checkpoint" in callbacks_cfg:
            print(
                "Caution: Saving checkpoints every n train steps without deleting. This might require some free space."
            )
            default_metrics_over_trainsteps_ckpt_dict = {
                "metrics_over_trainsteps_checkpoint": {
                    "target": "pytorch_lightning.callbacks.ModelCheckpoint",
                    "params": {
                        "dirpath": os.path.join(ckptdir, "trainstep_checkpoints"),
                        "filename": "{epoch:06}-{step:09}",
                        "verbose": True,
                        "save_top_k": -1,
                        "every_n_train_steps": 10000,
                        "save_weights_only": True,
                    },
                }
            }
            default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)

        callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
        if "ignore_keys_callback" in callbacks_cfg and hasattr(trainer_opt, "resume_from_checkpoint"):
            callbacks_cfg.ignore_keys_callback.params["ckpt_path"] = trainer_opt.resume_from_checkpoint
        elif "ignore_keys_callback" in callbacks_cfg:
            del callbacks_cfg["ignore_keys_callback"]

        trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
        trainer_kwargs["max_steps"] = trainer_opt.max_steps

        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            trainer_opt.accelerator = "mps"
            trainer_opt.detect_anomaly = False

        trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
        trainer.logdir = logdir

        # data
        config.data.params.train.params.data_root = opt.data_root
        config.data.params.validation.params.data_root = opt.data_root
        data = instantiate_from_config(config.data)

        # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
        # calling these ourselves should not be necessary but it is.
        # lightning still takes care of proper multiprocessing though
        data.prepare_data()
        data.setup()
        print("#### Data #####")
        for k in data.datasets:
            print(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}")

        # configure learning rate
        bs, base_lr = (
            config.data.params.batch_size,
            config.model.base_learning_rate,
        )
        if not cpu:
            gpus = str(lightning_config.trainer.gpus).strip(", ").split(",")
            ngpu = len(gpus)
        else:
            ngpu = 1
        if "accumulate_grad_batches" in lightning_config.trainer:
            accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
        else:
            accumulate_grad_batches = 1
        print(f"accumulate_grad_batches = {accumulate_grad_batches}")
        lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
        if opt.scale_lr:
            model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
            print(
                "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
                    model.learning_rate,
                    accumulate_grad_batches,
                    ngpu,
                    bs,
                    base_lr,
                )
            )
        else:
            model.learning_rate = base_lr
            print("++++ NOT USING LR SCALING ++++")
            print(f"Setting learning rate to {model.learning_rate:.2e}")

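        # Illustration (not part of the original file): a worked example of
        # the scaling rule above with hypothetical values. With
        # accumulate_grad_batches=2, 4 GPUs, batch_size=8 and
        # base_learning_rate=5.0e-6, the effective learning rate becomes
        # 2 * 4 * 8 * 5.0e-6 = 3.2e-4, i.e. it grows linearly with the
        # effective total batch size.
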
        # allow checkpointing via USR1
        def melk(*args, **kwargs):
            # run all checkpoint hooks
            if trainer.global_rank == 0:
                print("Summoning checkpoint.")
                ckpt_path = os.path.join(ckptdir, "last.ckpt")
                trainer.save_checkpoint(ckpt_path)

        def divein(*args, **kwargs):
            if trainer.global_rank == 0:
                import pudb

                pudb.set_trace()

        import signal

        # register the two handlers on distinct signals, as the comment above
        # intends; binding both to the same signal would make the second
        # registration silently override the first
        signal.signal(signal.SIGUSR1, melk)
        signal.signal(signal.SIGUSR2, divein)

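        # Illustration (not part of the original file): with the handlers
        # registered above, a running training job can be checkpointed or
        # inspected from another shell on POSIX systems, e.g.:
        #
        #   kill -USR1 <pid>   # saves checkpoints/last.ckpt via melk()
        #   kill -USR2 <pid>   # drops rank 0 into the pudb debugger
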
        # run
        if opt.train:
            try:
                trainer.fit(model, data)
            except Exception:
                melk()
                raise
        if not opt.no_test and not trainer.interrupted:
            trainer.test(model, data)
    except Exception:
        if opt.debug and trainer.global_rank == 0:
            try:
                import pudb as debugger
            except ImportError:
                import pdb as debugger
            debugger.post_mortem()
        raise
    finally:
        # move newly created debug project to debug_runs
        if opt.debug and not opt.resume and trainer.global_rank == 0:
            dst, name = os.path.split(logdir)
            dst = os.path.join(dst, "debug_runs", name)
            os.makedirs(os.path.split(dst)[0], exist_ok=True)
            os.rename(logdir, dst)
        # if trainer.global_rank == 0:
        #    print(trainer.profiler.summary())

@ -1,130 +0,0 @@
from ldm.modules.encoders.modules import FrozenCLIPEmbedder, BERTEmbedder
from ldm.modules.embedding_manager import EmbeddingManager
from ldm.invoke.globals import Globals

import argparse
from functools import partial

import torch


def get_placeholder_loop(placeholder_string, embedder, use_bert):
    new_placeholder = None

    while True:
        if new_placeholder is None:
            new_placeholder = input(
                f"Placeholder string {placeholder_string} was already used. Please enter a replacement string: "
            )
        else:
            new_placeholder = input(
                f"Placeholder string '{new_placeholder}' maps to more than a single token. Please enter another string: "
            )

        token = (
            get_bert_token_for_string(embedder.tknz_fn, new_placeholder)
            if use_bert
            else get_clip_token_for_string(embedder.tokenizer, new_placeholder)
        )

        if token is not None:
            return new_placeholder, token


def get_clip_token_for_string(tokenizer, string):
    batch_encoding = tokenizer(
        string,
        truncation=True,
        max_length=77,
        return_length=True,
        return_overflowing_tokens=False,
        padding="max_length",
        return_tensors="pt",
    )

    tokens = batch_encoding["input_ids"]

    # 49407 is the CLIP tokenizer's end-of-text id, which also pads the
    # sequence; the only entries differing from it are the start-of-text
    # token and the actual word tokens, so exactly two such entries means
    # the string maps to a single token
    if torch.count_nonzero(tokens - 49407) == 2:
        return tokens[0, 1]

    return None


def get_bert_token_for_string(tokenizer, string):
    token = tokenizer(string)
    # three non-zero entries correspond (presumably, mirroring the CLIP case
    # above) to the start marker, a single word token, and the end marker
    if torch.count_nonzero(token) == 3:
        return token[0, 1]

    return None


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--root_dir",
        type=str,
        default=".",
        help="Path to the InvokeAI install directory containing 'models', 'outputs' and 'configs'.",
    )

    parser.add_argument(
        "--manager_ckpts", type=str, nargs="+", required=True, help="Paths to a set of embedding managers to be merged."
    )

    parser.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="Output path for the merged manager",
    )

    parser.add_argument(
        "-sd",
        "--use_bert",
        action="store_true",
        help="Flag to denote that we are not merging stable diffusion embeddings",
    )

    args = parser.parse_args()
    Globals.root = args.root_dir

    if args.use_bert:
        embedder = BERTEmbedder(n_embed=1280, n_layer=32).cuda()
    else:
        embedder = FrozenCLIPEmbedder().cuda()

    EmbeddingManager = partial(EmbeddingManager, embedder, ["*"])

    string_to_token_dict = {}
    string_to_param_dict = torch.nn.ParameterDict()

    placeholder_to_src = {}

    for manager_ckpt in args.manager_ckpts:
        print(f"Parsing {manager_ckpt}...")

        manager = EmbeddingManager()
        manager.load(manager_ckpt)

        for placeholder_string in manager.string_to_token_dict:
            if placeholder_string not in string_to_token_dict:
                string_to_token_dict[placeholder_string] = manager.string_to_token_dict[placeholder_string]
                string_to_param_dict[placeholder_string] = manager.string_to_param_dict[placeholder_string]

                placeholder_to_src[placeholder_string] = manager_ckpt
            else:
                new_placeholder, new_token = get_placeholder_loop(placeholder_string, embedder, use_bert=args.use_bert)
                string_to_token_dict[new_placeholder] = new_token
                string_to_param_dict[new_placeholder] = manager.string_to_param_dict[placeholder_string]

                placeholder_to_src[new_placeholder] = manager_ckpt

    print("Saving combined manager...")
    merged_manager = EmbeddingManager()
    merged_manager.string_to_param_dict = string_to_param_dict
    merged_manager.string_to_token_dict = string_to_token_dict
    merged_manager.save(args.output_path)

    print("Managers merged. Final list of placeholders: ")
    print(placeholder_to_src)
@ -1,305 +0,0 @@
import argparse
import datetime
import glob
import os
import sys
import time
import yaml

import torch
import numpy as np
from tqdm import trange

from omegaconf import OmegaConf
from PIL import Image

from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config


def rescale(x: float) -> float:
    return (x + 1.0) / 2.0


def custom_to_pil(x):
    x = x.detach().cpu()
    x = torch.clamp(x, -1.0, 1.0)
    x = (x + 1.0) / 2.0
    x = x.permute(1, 2, 0).numpy()
    x = (255 * x).astype(np.uint8)
    x = Image.fromarray(x)
    if x.mode != "RGB":
        x = x.convert("RGB")
    return x


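# Illustration (not part of the original file): the conversion above maps the
# model's [-1, 1] value range onto 8-bit pixels, e.g. -1.0 -> 0, 0.0 -> 127
# (127.5 truncated by the uint8 cast), and 1.0 -> 255.
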
def custom_to_np(x):
    # saves the batch in adm style as in https://github.com/openai/guided-diffusion/blob/main/scripts/image_sample.py
    sample = x.detach().cpu()
    sample = ((sample + 1) * 127.5).clamp(0, 255).to(torch.uint8)
    sample = sample.permute(0, 2, 3, 1)
    sample = sample.contiguous()
    return sample


def logs2pil(logs, keys=["sample"]):
    imgs = dict()
    for k in logs:
        try:
            if len(logs[k].shape) == 4:
                img = custom_to_pil(logs[k][0, ...])
            elif len(logs[k].shape) == 3:
                img = custom_to_pil(logs[k])
            else:
                print(f"Unknown format for key {k}. ")
                img = None
        except Exception:
            img = None
        imgs[k] = img
    return imgs


@torch.no_grad()
def convsample(model, shape, return_intermediates=True, verbose=True, make_prog_row=False):
    if not make_prog_row:
        return model.p_sample_loop(None, shape, return_intermediates=return_intermediates, verbose=verbose)
    else:
        return model.progressive_denoising(None, shape, verbose=True)


@torch.no_grad()
def convsample_ddim(model, steps, shape, eta=1.0):
    ddim = DDIMSampler(model)
    bs = shape[0]
    shape = shape[1:]
    samples, intermediates = ddim.sample(
        steps,
        batch_size=bs,
        shape=shape,
        eta=eta,
        verbose=False,
    )
    return samples, intermediates


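# Illustration (not part of the original file): `eta` controls how much noise
# DDIM re-injects at each step. eta=0.0 gives fully deterministic DDIM
# sampling, while eta=1.0 recovers the stochasticity of ancestral DDPM-style
# sampling; this script defaults to 1.0 and exposes the value via --eta.
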
@torch.no_grad()
def make_convolutional_sample(
    model,
    batch_size,
    vanilla=False,
    custom_steps=None,
    eta=1.0,
):
    log = dict()

    shape = [
        batch_size,
        model.model.diffusion_model.in_channels,
        model.model.diffusion_model.image_size,
        model.model.diffusion_model.image_size,
    ]

    with model.ema_scope("Plotting"):
        t0 = time.time()
        if vanilla:
            sample, progrow = convsample(model, shape, make_prog_row=True)
        else:
            sample, intermediates = convsample_ddim(model, steps=custom_steps, shape=shape, eta=eta)

        t1 = time.time()

    x_sample = model.decode_first_stage(sample)

    log["sample"] = x_sample
    log["time"] = t1 - t0
    log["throughput"] = sample.shape[0] / (t1 - t0)
    print(f'Throughput for this batch: {log["throughput"]}')
    return log


def run(model, logdir, batch_size=50, vanilla=False, custom_steps=None, eta=None, n_samples=50000, nplog=None):
    if vanilla:
        print(f"Using Vanilla DDPM sampling with {model.num_timesteps} sampling steps.")
    else:
        print(f"Using DDIM sampling with {custom_steps} sampling steps and eta={eta}")

    tstart = time.time()
    n_saved = len(glob.glob(os.path.join(logdir, "*.png"))) - 1
    # path = logdir
    if model.cond_stage_model is None:
        all_images = []

        print(f"Running unconditional sampling for {n_samples} samples")
        for _ in trange(n_samples // batch_size, desc="Sampling Batches (unconditional)"):
            logs = make_convolutional_sample(
                model, batch_size=batch_size, vanilla=vanilla, custom_steps=custom_steps, eta=eta
            )
            n_saved = save_logs(logs, logdir, n_saved=n_saved, key="sample")
            all_images.extend([custom_to_np(logs["sample"])])
            if n_saved >= n_samples:
                print(f"Finished after generating {n_saved} samples")
                break
        all_img = np.concatenate(all_images, axis=0)
        all_img = all_img[:n_samples]
        shape_str = "x".join([str(x) for x in all_img.shape])
        nppath = os.path.join(nplog, f"{shape_str}-samples.npz")
        np.savez(nppath, all_img)

    else:
        raise NotImplementedError("Currently only sampling for unconditional models supported.")

    print(f"sampling of {n_saved} images finished in {(time.time() - tstart) / 60.:.2f} minutes.")


def save_logs(logs, path, n_saved=0, key="sample", np_path=None):
    for k in logs:
        if k == key:
            batch = logs[key]
            if np_path is None:
                for x in batch:
                    img = custom_to_pil(x)
                    imgpath = os.path.join(path, f"{key}_{n_saved:06}.png")
                    img.save(imgpath)
                    n_saved += 1
            else:
                npbatch = custom_to_np(batch)
                shape_str = "x".join([str(x) for x in npbatch.shape])
                nppath = os.path.join(np_path, f"{n_saved}-{shape_str}-samples.npz")
                np.savez(nppath, npbatch)
                n_saved += npbatch.shape[0]
    return n_saved


def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-r",
        "--resume",
        type=str,
        nargs="?",
        help="load from logdir or checkpoint in logdir",
    )
    parser.add_argument("-n", "--n_samples", type=int, nargs="?", help="number of samples to draw", default=50000)
    parser.add_argument(
        "-e",
        "--eta",
        type=float,
        nargs="?",
        help="eta for ddim sampling (0.0 yields deterministic sampling)",
        default=1.0,
    )
    parser.add_argument(
        "-v",
        "--vanilla_sample",
        default=False,
        action="store_true",
        help="vanilla sampling (default option is DDIM sampling)?",
    )
    parser.add_argument("-l", "--logdir", type=str, nargs="?", help="extra logdir", default="none")
    parser.add_argument(
        "-c", "--custom_steps", type=int, nargs="?", help="number of steps for ddim and fastdpm sampling", default=50
    )
    parser.add_argument("--batch_size", type=int, nargs="?", help="the batch size", default=10)
    return parser


def load_model_from_config(config, sd):
    model = instantiate_from_config(config)
    model.load_state_dict(sd, strict=False)
    model.cuda()
    model.eval()
    return model


def load_model(config, ckpt, gpu, eval_mode):
    if ckpt:
        print(f"Loading model from {ckpt}")
        pl_sd = torch.load(ckpt, map_location="cpu")
        global_step = pl_sd["global_step"]
    else:
        pl_sd = {"state_dict": None}
        global_step = None
    model = load_model_from_config(config.model, pl_sd["state_dict"])

    return model, global_step


if __name__ == "__main__":
    now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    sys.path.append(os.getcwd())
    command = " ".join(sys.argv)

    parser = get_parser()
    opt, unknown = parser.parse_known_args()
    ckpt = None

    if not os.path.exists(opt.resume):
        raise ValueError("Cannot find {}".format(opt.resume))
    if os.path.isfile(opt.resume):
        # paths = opt.resume.split("/")
        try:
            logdir = "/".join(opt.resume.split("/")[:-1])
            # idx = len(paths)-paths[::-1].index("logs")+1
            print(f"Logdir is {logdir}")
        except ValueError:
            paths = opt.resume.split("/")
            idx = -2  # take a guess: path/to/logdir/checkpoints/model.ckpt
            logdir = "/".join(paths[:idx])
        ckpt = opt.resume
    else:
        assert os.path.isdir(opt.resume), f"{opt.resume} is not a directory"
        logdir = opt.resume.rstrip("/")
        ckpt = os.path.join(logdir, "model.ckpt")

    base_configs = sorted(glob.glob(os.path.join(logdir, "config.yaml")))
    opt.base = base_configs

    configs = [OmegaConf.load(cfg) for cfg in opt.base]
    cli = OmegaConf.from_dotlist(unknown)
    config = OmegaConf.merge(*configs, cli)

    gpu = True
    eval_mode = True

    if opt.logdir != "none":
        locallog = logdir.split(os.sep)[-1]
        if locallog == "":
            locallog = logdir.split(os.sep)[-2]
        print(f"Switching logdir from '{logdir}' to '{os.path.join(opt.logdir, locallog)}'")
        logdir = os.path.join(opt.logdir, locallog)

    print(config)

    model, global_step = load_model(config, ckpt, gpu, eval_mode)
    print(f"global step: {global_step}")
    print(75 * "=")
    print("logging to:")
    logdir = os.path.join(logdir, "samples", f"{global_step:08}", now)
    imglogdir = os.path.join(logdir, "img")
    numpylogdir = os.path.join(logdir, "numpy")

    os.makedirs(imglogdir)
    os.makedirs(numpylogdir)
    print(logdir)
    print(75 * "=")

    # write config out
    sampling_file = os.path.join(logdir, "sampling_config.yaml")
    sampling_conf = vars(opt)

    with open(sampling_file, "w") as f:
        yaml.dump(sampling_conf, f, default_flow_style=False)
    print(sampling_conf)

    run(
        model,
        imglogdir,
        eta=opt.eta,
        vanilla=opt.vanilla_sample,
        n_samples=opt.n_samples,
        custom_steps=opt.custom_steps,
        batch_size=opt.batch_size,
        nplog=numpylogdir,
    )

    print("done.")
@ -1,169 +0,0 @@
import os
import sys
import numpy as np
import scann
import argparse
import glob
from multiprocessing import cpu_count
from tqdm import tqdm

from ldm.util import parallel_data_prefetch


def search_bruteforce(searcher):
    return searcher.score_brute_force().build()


def search_partioned_ah(
    searcher, dims_per_block, aiq_threshold, reorder_k, partioning_trainsize, num_leaves, num_leaves_to_search
):
    return (
        searcher.tree(
            num_leaves=num_leaves, num_leaves_to_search=num_leaves_to_search, training_sample_size=partioning_trainsize
        )
        .score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold)
        .reorder(reorder_k)
        .build()
    )


def search_ah(searcher, dims_per_block, aiq_threshold, reorder_k):
    return (
        searcher.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(reorder_k).build()
    )


def load_datapool(dpath):
    def load_single_file(saved_embeddings):
        compressed = np.load(saved_embeddings)
        database = {key: compressed[key] for key in compressed.files}
        return database

    def load_multi_files(data_archive):
        database = {key: [] for key in data_archive[0].files}
        for d in tqdm(data_archive, desc=f"Loading datapool from {len(data_archive)} individual files."):
            for key in d.files:
                database[key].append(d[key])

        return database

    print(f'Load saved patch embedding from "{dpath}"')
    file_content = glob.glob(os.path.join(dpath, "*.npz"))

    if len(file_content) == 1:
        data_pool = load_single_file(file_content[0])
    elif len(file_content) > 1:
        data = [np.load(f) for f in file_content]
        prefetched_data = parallel_data_prefetch(
            load_multi_files, data, n_proc=min(len(data), cpu_count()), target_data_type="dict"
        )

        data_pool = {
            key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in prefetched_data[0].keys()
        }
    else:
        raise ValueError(f'No npz-files in specified path "{dpath}" -- does this directory exist?')

    print(f'Finished loading of retrieval database of length {data_pool["embedding"].shape[0]}.')
    return data_pool


def train_searcher(
    opt,
    metric="dot_product",
    partioning_trainsize=None,
    reorder_k=None,
    # todo tune
    aiq_thld=0.2,
    dims_per_block=2,
    num_leaves=None,
    num_leaves_to_search=None,
):
    data_pool = load_datapool(opt.database)
    k = opt.knn

    if not reorder_k:
        reorder_k = 2 * k

    # normalize
    # embeddings =
    searcher = scann.scann_ops_pybind.builder(
        data_pool["embedding"] / np.linalg.norm(data_pool["embedding"], axis=1)[:, np.newaxis], k, metric
    )
    pool_size = data_pool["embedding"].shape[0]

    print(*(["#"] * 100))
    print("Initializing scaNN searcher with the following values:")
    print(f"k: {k}")
    print(f"metric: {metric}")
    print(f"reorder_k: {reorder_k}")
    print(f"anisotropic_quantization_threshold: {aiq_thld}")
    print(f"dims_per_block: {dims_per_block}")
    print(*(["#"] * 100))
    print("Start training searcher....")
    print(f"N samples in pool is {pool_size}")

    # this reflects the recommended design choices proposed at
    # https://github.com/google-research/google-research/blob/aca5f2e44e301af172590bb8e65711f0c9ee0cfd/scann/docs/algorithms.md
    if pool_size < 2e4:
        print("Using brute force search.")
        searcher = search_bruteforce(searcher)
    elif 2e4 <= pool_size < 1e5:
        print("Using asymmetric hashing search and reordering.")
        searcher = search_ah(searcher, dims_per_block, aiq_thld, reorder_k)
    else:
        print("Using partitioning, asymmetric hashing search and reordering.")

        if not partioning_trainsize:
            partioning_trainsize = data_pool["embedding"].shape[0] // 10
        if not num_leaves:
            num_leaves = int(np.sqrt(pool_size))

        if not num_leaves_to_search:
            num_leaves_to_search = max(num_leaves // 20, 1)

        print("Partitioning params:")
        print(f"num_leaves: {num_leaves}")
        print(f"num_leaves_to_search: {num_leaves_to_search}")
        # self.searcher = self.search_ah(searcher, dims_per_block, aiq_thld, reorder_k)
        searcher = search_partioned_ah(
            searcher, dims_per_block, aiq_thld, reorder_k, partioning_trainsize, num_leaves, num_leaves_to_search
        )

    print("Finished training searcher")
    searcher_savedir = opt.target_path
    os.makedirs(searcher_savedir, exist_ok=True)
    searcher.serialize(searcher_savedir)
    print(f'Saved trained searcher under "{searcher_savedir}"')


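# Illustration (not part of the original file): a worked example of the
# partitioning heuristics above. For a hypothetical pool of 1,000,000
# embeddings with the default k=20:
#   num_leaves           = int(sqrt(1_000_000)) = 1000
#   num_leaves_to_search = max(1000 // 20, 1)   = 50
#   reorder_k            = 2 * 20               = 40
# so roughly 5% of the partitions are probed and the top 40 candidates are
# re-scored exactly before the final 20 neighbors are returned.
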
if __name__ == "__main__":
    sys.path.append(os.getcwd())
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--database",
        "-d",
        default="data/rdm/retrieval_databases/openimages",
        type=str,
        help="path to folder containing the clip feature of the database",
    )
    parser.add_argument(
        "--target_path",
        "-t",
        default="data/rdm/searchers/openimages",
        type=str,
        help="path to the target folder where the searcher shall be stored.",
    )
    parser.add_argument(
        "--knn",
        "-k",
        default=20,
        type=int,
        help="number of nearest neighbors, for which the searcher shall be optimized",
    )

    opt, _ = parser.parse_known_args()

    train_searcher(
        opt,
    )
@ -1,316 +0,0 @@
import argparse
import os
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange
from torchvision.utils import make_grid
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import nullcontext

import k_diffusion as K
import torch.nn as nn

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.invoke.devices import choose_torch_device


def chunk(it, size):
    it = iter(it)
    return iter(lambda: tuple(islice(it, size)), ())


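# Illustration (not part of the original file): chunk() lazily groups any
# iterable into size-limited tuples, with a shorter final tuple when the
# length is not a multiple of size:
#
#   list(chunk(["a", "b", "c"], 2))  # -> [("a", "b"), ("c",)]
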
def load_model_from_config(config, ckpt, verbose=False):
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)

    model.to(choose_torch_device())
    model.eval()
    return model


def main():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--prompt",
        type=str,
        nargs="?",
        default="a painting of a virus monster playing guitar",
        help="the prompt to render",
    )
    parser.add_argument(
        "--outdir", type=str, nargs="?", help="dir to write results to", default="outputs/txt2img-samples"
    )
    parser.add_argument(
        "--skip_grid",
        action="store_true",
        help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
    )
    parser.add_argument(
        "--skip_save",
        action="store_true",
        help="do not save individual samples. For speed measurements.",
    )
    parser.add_argument(
        "--ddim_steps",
        type=int,
        default=50,
        help="number of ddim sampling steps",
    )
    parser.add_argument(
        "--plms",
        action="store_true",
        help="use plms sampling",
    )
    parser.add_argument(
        "--klms",
        action="store_true",
        help="use klms sampling",
    )
    parser.add_argument(
        "--laion400m",
        action="store_true",
        help="uses the LAION400M model",
    )
    parser.add_argument(
        "--fixed_code",
        action="store_true",
        help="if enabled, uses the same starting code across samples",
    )
    parser.add_argument(
        "--ddim_eta",
        type=float,
        default=0.0,
        help="ddim eta (eta=0.0 corresponds to deterministic sampling)",
    )
    parser.add_argument(
        "--n_iter",
        type=int,
        default=2,
        help="sample this often",
    )
    parser.add_argument(
        "--H",
        type=int,
        default=512,
        help="image height, in pixel space",
    )
    parser.add_argument(
        "--W",
        type=int,
        default=512,
        help="image width, in pixel space",
    )
    parser.add_argument(
        "--C",
        type=int,
        default=4,
        help="latent channels",
    )
    parser.add_argument(
        "--f",
        type=int,
        default=8,
        help="downsampling factor",
    )
    parser.add_argument(
        "--n_samples",
        type=int,
        default=3,
        help="how many samples to produce for each given prompt. A.k.a. batch size",
    )
    parser.add_argument(
        "--n_rows",
        type=int,
        default=0,
        help="rows in the grid (default: n_samples)",
    )
    parser.add_argument(
        "--scale",
        type=float,
        default=7.5,
        help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
    )
    parser.add_argument(
        "--from-file",
        type=str,
        help="if specified, load prompts from this file",
    )
    parser.add_argument(
        "--config",
        type=str,
        default="configs/stable-diffusion/v1-inference.yaml",
        help="path to config which constructs model",
    )
    parser.add_argument(
        "--ckpt",
        type=str,
        default="models/ldm/stable-diffusion-v1/model.ckpt",
        help="path to checkpoint of model",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="the seed (for reproducible sampling)",
    )
    parser.add_argument(
        "--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast"
    )
    opt = parser.parse_args()

    if opt.laion400m:
        print("Falling back to LAION 400M model...")
        opt.config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml"
        opt.ckpt = "models/ldm/text2img-large/model.ckpt"
        opt.outdir = "outputs/txt2img-samples-laion400m"

    config = OmegaConf.load(f"{opt.config}")
    model = load_model_from_config(config, f"{opt.ckpt}")

    seed_everything(opt.seed)

    device = torch.device(choose_torch_device())
    model = model.to(device)

    # for klms
    model_wrap = K.external.CompVisDenoiser(model)

    class CFGDenoiser(nn.Module):
        def __init__(self, model):
            super().__init__()
            self.inner_model = model

        def forward(self, x, sigma, uncond, cond, cond_scale):
            x_in = torch.cat([x] * 2)
            sigma_in = torch.cat([sigma] * 2)
            cond_in = torch.cat([uncond, cond])
            uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
            return uncond + (cond - uncond) * cond_scale

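    # Illustration (not part of the original file): the return line above is
    # classifier-free guidance. With hypothetical per-pixel predictions
    # uncond = 0.0 and cond = 1.0 at cond_scale = 7.5, the guided output is
    # 0.0 + (1.0 - 0.0) * 7.5 = 7.5: the conditional direction is amplified
    # 7.5x beyond the unconditional prediction. cond_scale = 1.0 would simply
    # return the conditional prediction unchanged.
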
    if opt.plms:
        sampler = PLMSSampler(model)
    else:
        sampler = DDIMSampler(model)

    os.makedirs(opt.outdir, exist_ok=True)
    outpath = opt.outdir

    batch_size = opt.n_samples
    n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
    if not opt.from_file:
        prompt = opt.prompt
        assert prompt is not None
        data = [batch_size * [prompt]]

    else:
        print(f"reading prompts from {opt.from_file}")
        with open(opt.from_file, "r") as f:
            data = f.read().splitlines()
            if len(data) >= batch_size:
                data = list(chunk(data, batch_size))
            else:
                while len(data) < batch_size:
                    data.append(data[-1])
                data = [data]

    sample_path = os.path.join(outpath, "samples")
    os.makedirs(sample_path, exist_ok=True)
    base_count = len(os.listdir(sample_path))
    grid_count = len(os.listdir(outpath)) - 1

    start_code = None
    if opt.fixed_code:
        shape = [opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f]
        if device.type == "mps":
            start_code = torch.randn(shape, device="cpu").to(device)
        else:
            # the original dropped the assignment here, so --fixed_code had no
            # effect off-MPS; the result has to be stored in start_code
            start_code = torch.randn(shape, device=device)

    precision_scope = autocast if opt.precision == "autocast" else nullcontext
    if device.type in ["mps", "cpu"]:
        precision_scope = nullcontext  # have to use f32 on mps
    with torch.no_grad():
        with precision_scope(device.type):
            with model.ema_scope():
                all_samples = list()
                for n in trange(opt.n_iter, desc="Sampling"):
                    for prompts in tqdm(data, desc="data"):
                        uc = None
                        if opt.scale != 1.0:
                            uc = model.get_learned_conditioning(batch_size * [""])
                        if isinstance(prompts, tuple):
                            prompts = list(prompts)
                        c = model.get_learned_conditioning(prompts)
                        shape = [opt.C, opt.H // opt.f, opt.W // opt.f]

                        if not opt.klms:
                            samples_ddim, _ = sampler.sample(
                                S=opt.ddim_steps,
                                conditioning=c,
                                batch_size=opt.n_samples,
                                shape=shape,
                                verbose=False,
                                unconditional_guidance_scale=opt.scale,
                                unconditional_conditioning=uc,
                                eta=opt.ddim_eta,
                                x_T=start_code,
                            )
                        else:
                            sigmas = model_wrap.get_sigmas(opt.ddim_steps)
                            # the truthiness of a multi-element tensor is
                            # ambiguous, so compare against None instead of
                            # the original `if start_code:`
                            if start_code is not None:
                                x = start_code
                            else:
                                x = torch.randn([opt.n_samples, *shape], device=device) * sigmas[0]  # for GPU draw
                            model_wrap_cfg = CFGDenoiser(model_wrap)
                            extra_args = {"cond": c, "uncond": uc, "cond_scale": opt.scale}
                            samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args)

                        x_samples_ddim = model.decode_first_stage(samples_ddim)
                        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

                        if not opt.skip_save:
                            for x_sample in x_samples_ddim:
                                x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c")
                                Image.fromarray(x_sample.astype(np.uint8)).save(
                                    os.path.join(sample_path, f"{base_count:05}.png")
                                )
                                base_count += 1

                        if not opt.skip_grid:
                            all_samples.append(x_samples_ddim)

                if not opt.skip_grid:
                    # additionally, save as grid
                    grid = torch.stack(all_samples, 0)
                    grid = rearrange(grid, "n b c h w -> (n b) c h w")
                    grid = make_grid(grid, nrow=n_rows)

                    # to image
                    grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy()
                    Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f"grid-{grid_count:04}.png"))
                    grid_count += 1

    print(f"Your samples are ready and waiting for you here: \n{outpath}\n\nEnjoy.")


if __name__ == "__main__":
    main()