2021-12-21 02:23:41 +00:00
|
|
|
from torchvision.datasets.utils import download_url
|
|
|
|
from ldm.util import instantiate_from_config
|
|
|
|
import torch
|
|
|
|
import os
|
2023-07-27 14:54:01 +00:00
|
|
|
|
2021-12-21 02:23:41 +00:00
|
|
|
# todo ?
|
|
|
|
from google.colab import files
|
|
|
|
from IPython.display import Image as ipyimg
|
|
|
|
import ipywidgets as widgets
|
|
|
|
from PIL import Image
|
|
|
|
from einops import rearrange, repeat
|
2023-08-17 22:45:25 +00:00
|
|
|
import torchvision
|
2021-12-21 02:23:41 +00:00
|
|
|
from ldm.models.diffusion.ddim import DDIMSampler
|
|
|
|
from ldm.util import ismap
|
|
|
|
import time
|
|
|
|
from omegaconf import OmegaConf
|
2022-10-08 15:37:23 +00:00
|
|
|
from ldm.invoke.devices import choose_torch_device
|
2021-12-21 02:23:41 +00:00
|
|
|
|
|
|
|
|
2023-07-27 14:54:01 +00:00
|
|
|
def download_models(mode):
    """Download the config and checkpoint for ``mode`` and return their local paths.

    Only ``"superresolution"`` (the small BSR light model) is supported.

    Args:
        mode: Name of the task whose pretrained model should be fetched.

    Returns:
        A ``(path_conf, path_ckpt)`` tuple of local file paths.

    Raises:
        NotImplementedError: If ``mode`` is not ``"superresolution"``.
    """
    if mode != "superresolution":
        raise NotImplementedError(f"no pretrained model available for mode {mode}")

    # this is the small bsr light model
    url_conf = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
    url_ckpt = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"

    path_conf = "logs/diffusion/superresolution_bsr/configs/project.yaml"
    path_ckpt = "logs/diffusion/superresolution_bsr/checkpoints/last.ckpt"

    # download_url(url, root, filename) treats ``root`` as a directory and, when no
    # filename is given, names the file after the URL's basename ("?dl=1"). Passing
    # the directory and the file name separately stores each download at the
    # intended path instead of "<path>/?dl=1" (a name that is invalid on Windows).
    for url, path in ((url_conf, path_conf), (url_ckpt, path_ckpt)):
        download_url(url, os.path.dirname(path), filename=os.path.basename(path))

    return path_conf, path_ckpt
|
|
|
|
|
|
|
|
|
|
|
|
def load_model_from_config(config, ckpt):
    """Instantiate ``config.model`` and load its weights from the checkpoint ``ckpt``.

    Args:
        config: OmegaConf config whose ``model`` node is passed to
            ``instantiate_from_config``.
        ckpt: Path to a checkpoint file holding ``"state_dict"`` and
            ``"global_step"`` entries.

    Returns:
        ``({"model": model}, global_step)``. The model is in eval mode, on the
        device returned by ``choose_torch_device()``.
    """
    print(f"Loading model from {ckpt}")
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    pl_sd = torch.load(ckpt, map_location="cpu")
    global_step = pl_sd["global_step"]
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    # strict=False: missing/unexpected keys are tolerated and not reported here.
    model.load_state_dict(sd, strict=False)
    # Use the same device selection as get_cond() instead of hard-coding CUDA, so
    # the model and its conditioning share a device (and CPU/MPS machines work).
    model.to(choose_torch_device())
    model.eval()
    return {"model": model}, global_step
|
|
|
|
|
|
|
|
|
|
|
|
def get_model(mode):
    """Fetch the pretrained files for ``mode`` and return the loaded model dict.

    The training step recorded in the checkpoint is discarded.
    """
    conf_path, ckpt_path = download_models(mode)
    model_dict, _ = load_model_from_config(OmegaConf.load(conf_path), ckpt_path)
    return model_dict
|
|
|
|
|
|
|
|
|
|
|
|
def get_custom_cond(mode):
    """Store a user-supplied conditioning input under ``data/example_conditioning/<mode>/``.

    - ``"superresolution"``: uploads an image via Colab and moves it to
      ``custom_<name><ext>``.
    - ``"text_conditional"``: writes the text widget's value to
      ``custom_<first 20 chars>.txt``.
    - ``"class_conditional"``: writes the slider's integer value to ``custom.txt``.

    The target directory is created if it does not exist.

    Raises:
        NotImplementedError: For any other ``mode``.
    """
    dest = "data/example_conditioning"
    target_dir = os.path.join(dest, mode)

    if mode == "superresolution":
        uploaded_img = files.upload()
        filename = next(iter(uploaded_img))
        # splitext splits at the last dot only, so names with several dots (or none)
        # no longer break the unpacking.
        name, ext = os.path.splitext(filename)
        os.makedirs(target_dir, exist_ok=True)
        os.rename(filename, f"{dest}/{mode}/custom_{name}{ext}")

    elif mode == "text_conditional":
        w = widgets.Text(value="A cake with cream!", disabled=True)
        display(w)  # noqa: F821
        # NOTE(review): w.value is read right after display(), so this records the
        # widget's initial value rather than anything typed interactively.
        os.makedirs(target_dir, exist_ok=True)
        with open(f"{dest}/{mode}/custom_{w.value[:20]}.txt", "w") as f:
            f.write(w.value)

    elif mode == "class_conditional":
        w = widgets.IntSlider(min=0, max=1000)
        display(w)  # noqa: F821
        os.makedirs(target_dir, exist_ok=True)
        with open(f"{dest}/{mode}/custom.txt", "w") as f:
            # IntSlider.value is an int, and text-mode write() only accepts str.
            f.write(str(w.value))

    else:
        raise NotImplementedError(f"cond not implemented for mode {mode}")
|
|
|
|
|
|
|
|
|
|
|
|
def get_cond_options(mode):
    """Return the conditioning directory for ``mode`` together with its contents.

    Returns:
        ``(path, names)``: ``path`` is ``data/example_conditioning/<mode>``, and
        ``names`` is that directory's listing in sorted order.
    """
    cond_dir = os.path.join("data/example_conditioning", mode)
    entries = os.listdir(cond_dir)
    entries.sort()
    return cond_dir, entries
|
|
|
|
|
|
|
|
|
|
|
|
def select_cond_path(mode):
    """Show a radio-button picker over the example conditioning files for ``mode``.

    The directory and its sorted listing come from ``get_cond_options``, so both
    helpers agree on where conditioning files live.

    Returns:
        The path of the option selected when the widget is created.

    Raises:
        FileNotFoundError: If the conditioning directory contains no files.
    """
    path, onlyfiles = get_cond_options(mode)
    if not onlyfiles:
        # With no options the widget presumably has no value, and
        # os.path.join(path, None) would fail with an opaque TypeError.
        raise FileNotFoundError(f"no conditioning files found in {path}")

    selected = widgets.RadioButtons(options=onlyfiles, description="Select conditioning:", disabled=False)
    display(selected)  # noqa: F821
    # NOTE(review): value is read immediately after display(), so this is the
    # initial selection rather than an interactive choice.
    selected_path = os.path.join(path, selected.value)
    return selected_path
|
|
|
|
|
|
|
|
|
|
|
|
def get_cond(mode, selected_path):
    """Build the conditioning batch for ``mode`` from the image at ``selected_path``.

    For ``"superresolution"`` the image is first shown inline, and the returned
    dict contains:

    - ``"LR_image"``: the input as a ``1 h w c`` tensor rescaled from [0, 1] to
      [-1, 1], moved to ``choose_torch_device()``;
    - ``"image"``: the input resized by a factor of 4, as a ``1 h w c`` tensor kept
      in ``ToTensor``'s [0, 1] range and left on the CPU.

    For any other mode an empty dict is returned.
    """
    example = dict()
    if mode == "superresolution":
        up_f = 4
        visualize_cond_img(selected_path)

        # Load inside the context so the file handle is closed. Forcing RGB keeps the
        # tensor at 3 channels even for RGBA, greyscale or palette images, which
        # ToTensor would otherwise turn into 4- or 1-channel tensors.
        with Image.open(selected_path) as img:
            rgb = img.convert("RGB")
        c = torch.unsqueeze(torchvision.transforms.ToTensor()(rgb), 0)
        c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]], antialias=True)
        c_up = rearrange(c_up, "1 c h w -> 1 h w c")
        c = rearrange(c, "1 c h w -> 1 h w c")
        c = 2.0 * c - 1.0

        device = choose_torch_device()
        c = c.to(device)
        example["LR_image"] = c
        example["image"] = c_up

    return example
|
|
|
|
|
|
|
|
|
|
|
|
def visualize_cond_img(path):
    """Render the image file at ``path`` inline in the notebook output."""
    preview = ipyimg(filename=path)
    display(preview)  # noqa: F821
|
2021-12-21 02:23:41 +00:00
|
|
|
|
|
|
|
|
|
|
|
def run(model, selected_path, task, custom_steps, resize_enabled=False, classifier_ckpt=None, global_step=None):
    """Sample from ``model`` conditioned on the file at ``selected_path``.

    Args:
        model: Diffusion model. Before sampling, its ``split_input_params``
            attribute is set when the conditioning image is at least 128x128, and
            removed otherwise.
        selected_path: Path of the conditioning input, forwarded to ``get_cond``.
        task: Conditioning mode forwarded to ``get_cond``.
        custom_steps: Number of DDIM sampling steps.
        resize_enabled: Forwarded to ``make_convolutional_sample``.
        classifier_ckpt: Unused; kept for interface compatibility.
        global_step: Unused; kept for interface compatibility.

    Returns:
        The log dict produced by the final ``make_convolutional_sample`` call.
    """
    example = get_cond(task, selected_path)

    # Fixed sampling configuration used by this helper.
    n_runs = 1
    custom_shape = None
    sampler_kwargs = dict(
        mode="ddim",
        custom_steps=custom_steps,
        eta=1.0,
        swap_mode=False,
        masked=False,
        invert_mask=False,
        quantize_x0=False,
        custom_schedule=None,
        decode_interval=10,
        resize_enabled=resize_enabled,
        custom_shape=custom_shape,
        temperature=1.0,
        noise_dropout=0.0,
        corrector=None,
        corrector_kwargs=None,
        save_intermediate_vid=False,
        make_progrow=True,
        ddim_use_x0_pred=False,
    )

    # Dimensions 1 and 2 of example["image"] are read as height and width.
    height, width = example["image"].shape[1:3]
    if height >= 128 and width >= 128:
        patch_size, patch_stride = 128, 64
        model.split_input_params = {
            "ks": (patch_size, patch_size),
            "stride": (patch_stride, patch_stride),
            "vqf": 4,
            "patch_distributed_vq": True,
            "tie_braker": False,
            "clip_max_weight": 0.5,
            "clip_min_weight": 0.01,
            "clip_max_tie_weight": 0.5,
            "clip_min_tie_weight": 0.01,
        }
    elif hasattr(model, "split_input_params"):
        delattr(model, "split_input_params")

    x_T = None
    for _ in range(n_runs):
        if custom_shape is not None:
            noise = torch.randn(1, *custom_shape[1:4]).to(model.device)
            x_T = repeat(noise, "1 c h w -> b c h w", b=custom_shape[0])
        logs = make_convolutional_sample(example, model, x_T=x_T, **sampler_kwargs)
    return logs
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def convsample_ddim(
    model,
    cond,
    steps,
    shape,
    eta=1.0,
    callback=None,
    normals_sequence=None,
    mask=None,
    x0=None,
    quantize_x0=False,
    img_callback=None,
    temperature=1.0,
    noise_dropout=0.0,
    score_corrector=None,
    corrector_kwargs=None,
    x_T=None,
    log_every_t=None,
):
    """Run DDIM sampling for ``model`` and return ``(samples, intermediates)``.

    Args:
        model: Diffusion model wrapped in a ``DDIMSampler``.
        cond: Conditioning, passed as ``conditioning``.
        steps: Number of DDIM steps.
        shape: Full shape including the batch dimension. ``shape[0]`` becomes the
            batch size and ``shape[1:]`` the per-sample shape.
        log_every_t: Interval for recording intermediates. ``None`` leaves the
            sampler's own default in place.

    All other keyword arguments are forwarded unchanged to ``DDIMSampler.sample``.
    """
    ddim = DDIMSampler(model)
    bs = shape[0]
    shape = shape[1:]  # cut batch dim
    print(f"Sampling with eta = {eta}; steps: {steps}")

    extra_kwargs = {}
    if log_every_t is not None:
        # Forward only an explicit interval; passing None would replace the
        # sampler's default with a non-integer.
        extra_kwargs["log_every_t"] = log_every_t

    # img_callback and noise_dropout used to be accepted here but silently
    # dropped; they are now forwarded as well.
    samples, intermediates = ddim.sample(
        steps,
        batch_size=bs,
        shape=shape,
        conditioning=cond,
        callback=callback,
        normals_sequence=normals_sequence,
        img_callback=img_callback,
        quantize_x0=quantize_x0,
        eta=eta,
        mask=mask,
        x0=x0,
        temperature=temperature,
        noise_dropout=noise_dropout,
        verbose=False,
        score_corrector=score_corrector,
        corrector_kwargs=corrector_kwargs,
        x_T=x_T,
        **extra_kwargs,
    )
    return samples, intermediates
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def make_convolutional_sample(
    batch,
    model,
    mode="vanilla",
    custom_steps=None,
    eta=1.0,
    swap_mode=False,
    masked=False,
    invert_mask=True,
    quantize_x0=False,
    custom_schedule=None,
    decode_interval=1000,
    resize_enabled=False,
    custom_shape=None,
    temperature=1.0,
    noise_dropout=0.0,
    corrector=None,
    corrector_kwargs=None,
    x_T=None,
    save_intermediate_vid=False,
    make_progrow=True,
    ddim_use_x0_pred=False,
):
    """Draw a DDIM sample from ``model`` for ``batch`` and collect tensors for display.

    Sampling always goes through ``convsample_ddim``. The parameters ``mode``,
    ``swap_mode``, ``masked``, ``invert_mask``, ``custom_schedule``,
    ``decode_interval``, ``resize_enabled`` and ``make_progrow`` are accepted but
    not used.

    Returns:
        A dict with ``"input"``, ``"reconstruction"``, ``"original_conditioning"``,
        possibly an entry under ``model.cond_stage_key``, ``"sample"`` and
        ``"time"`` (sampling wall-clock seconds). When the unquantized decode
        succeeds it also holds ``"sample_noquant"`` and ``"sample_diff"``.
    """
    encode_cond = not (hasattr(model, "split_input_params") and model.cond_stage_key == "coordinates_bbox")
    z, c, x, xrec, xc = model.get_input(
        batch,
        model.first_stage_key,
        return_first_stage_outputs=True,
        force_c_encode=encode_cond,
        return_original_cond=True,
    )

    if custom_shape is not None:
        # Only z's shape is used below, as the sampling shape.
        z = torch.randn(custom_shape)
        print(f"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}")

    log = {"input": x, "reconstruction": xrec}

    if ismap(xc):
        log["original_conditioning"] = model.to_rgb(xc)
        if hasattr(model, "cond_stage_key"):
            log[model.cond_stage_key] = model.to_rgb(xc)
    else:
        log["original_conditioning"] = xc if xc is not None else torch.zeros_like(x)
        if model.cond_stage_model:
            log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x)
            if model.cond_stage_key == "class_label":
                log[model.cond_stage_key] = xc[model.cond_stage_key]

    with model.ema_scope("Plotting"):
        t0 = time.time()
        sample, intermediates = convsample_ddim(
            model,
            c,
            steps=custom_steps,
            shape=z.shape,
            eta=eta,
            quantize_x0=quantize_x0,
            img_callback=None,
            mask=None,
            x0=None,
            temperature=temperature,
            noise_dropout=noise_dropout,
            score_corrector=corrector,
            corrector_kwargs=corrector_kwargs,
            x_T=x_T,
            # Record every step only when an intermediate video is requested.
            log_every_t=1 if save_intermediate_vid else None,
        )
        t1 = time.time()
        if ddim_use_x0_pred:
            sample = intermediates["pred_x0"][-1]

    x_sample = model.decode_first_stage(sample)

    try:
        # Best-effort extra decode; presumably not every first stage supports
        # force_not_quantize, so any failure is ignored.
        x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
        log["sample_noquant"] = x_sample_noquant
        log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
    except Exception:
        pass

    log["sample"] = x_sample
    log["time"] = t1 - t0
    return log
|