2022-10-25 04:38:24 +00:00
|
|
|
"""omnibus module to be used with the runwayml 9-channel custom inpainting model"""
|
|
|
|
|
|
|
|
import torch
|
|
|
|
import numpy as np
|
2022-10-25 04:47:13 +00:00
|
|
|
from einops import repeat
|
2022-10-25 14:00:28 +00:00
|
|
|
from PIL import Image, ImageOps
|
2022-10-25 04:38:24 +00:00
|
|
|
from ldm.invoke.generator.base import downsampling
|
|
|
|
from ldm.invoke.generator.img2img import Img2Img
|
|
|
|
from ldm.invoke.generator.txt2img import Txt2Img
|
|
|
|
|
|
|
|
class Omnibus(Img2Img, Txt2Img):
    """Generator that serves txt2img, img2img and inpainting through one
    code path for a model conditioned on a concatenated mask and masked image.

    All three modes are reduced to the same (image, mask, masked_image)
    triple, which is handed to the sampler as ``c_concat`` conditioning.
    """

    def __init__(self, model, precision):
        super().__init__(model, precision)

    def get_make_image(
            self,
            prompt,
            sampler,
            steps,
            cfg_scale,
            ddim_eta,
            conditioning,
            width,
            height,
            init_image = None,
            mask_image = None,
            strength = None,
            step_callback=None,
            threshold=0.0,
            perlin=0.0,
            **kwargs):
        """
        Returns a function returning an image derived from the prompt and the initial image
        Return value depends on the seed at the time you call it.

        Mode is selected from the arguments:
          * ``init_image`` and ``mask_image`` given -> inpainting
          * only ``init_image`` given               -> img2img
          * neither given                           -> txt2img

        ``init_image`` / ``mask_image`` may be PIL images (converted with
        ``self._image_to_tensor``) or already-prepared tensors.  A PIL mask is
        inverted and converted to mode 'L' before conversion.

        ``conditioning`` is accepted for interface compatibility but is not
        used here: the prompt is re-encoded inside the returned closure.
        ``strength`` is only consulted in img2img mode, where it scales the
        number of sampling steps.
        """
        self.perlin = perlin
        # get_noise() dispatches on init_latent; clear it so a value left over
        # from an earlier img2img call on this instance cannot leak into a
        # txt2img / inpainting call.
        self.init_latent = None
        num_samples = 1
        device = self.model.device

        sampler.make_schedule(
            ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
        )

        if isinstance(init_image, Image.Image):
            init_image = self._image_to_tensor(init_image)

        if isinstance(mask_image, Image.Image):
            mask_image = self._image_to_tensor(
                ImageOps.invert(mask_image).convert('L'),
                normalize=False,
            )

        t_enc = steps

        if init_image is not None and mask_image is not None:  # inpainting
            # masked image is the image masked by mask - masked regions zero
            masked_image = init_image * (1 - mask_image)

        elif init_image is not None:  # img2img
            # choose_autocast was referenced without being imported anywhere in
            # this module; import it locally so this branch no longer raises
            # NameError.
            from ldm.invoke.devices import choose_autocast

            scope = choose_autocast(self.precision)
            with scope(device.type):
                self.init_latent = self.model.get_first_stage_encoding(
                    self.model.encode_first_stage(init_image)
                )  # move to latent space

            # Mask of all ones, single channel like the txt2img branch.
            # Tensors have no .width/.height; the spatial size is taken from
            # the last two dims (assumes init_image is laid out N,C,H,W, as
            # the latent shapes below imply -- TODO confirm for tensor inputs).
            mask_image = torch.ones(
                init_image.shape[0],
                1,
                init_image.shape[-2],
                init_image.shape[-1],
                device=device,
            )
            # and the masked image is just a copy of the original
            masked_image = init_image
            # strength defaults to None; only scale the step count when given
            if strength is not None:
                t_enc = int(strength * steps)

        else:  # txt2img
            # (height, width) order to agree with the [.., height//8, width//8]
            # latent shape used by the sampler below; the transposed order
            # only worked for square images.
            init_image = torch.zeros(1, 3, height, width, device=device)
            mask_image = torch.ones(1, 1, height, width, device=device)
            masked_image = init_image

        model = self.model
        latent_hw = [height // 8, width // 8]

        def make_image(x_T):
            with torch.no_grad():
                # NOTE(review): device type is hard-coded to "cuda" here while
                # the img2img branch above uses model.device.type -- confirm
                # whether non-CUDA devices need to be supported.
                with torch.autocast("cuda"):

                    batch = self.make_batch_sd(
                        init_image,
                        mask_image,
                        masked_image,
                        prompt=prompt,
                        device=model.device,
                        num_samples=num_samples,
                    )

                    c = model.cond_stage_model.encode(batch["txt"])

                    # Build the concatenated conditioning: the masked image is
                    # encoded to latent space, every other key is resized to
                    # the latent resolution.
                    c_cat = list()
                    for ck in model.concat_keys:
                        cc = batch[ck].float()
                        if ck != model.masked_image_key:
                            cc = torch.nn.functional.interpolate(cc, size=latent_hw)
                        else:
                            cc = model.get_first_stage_encoding(model.encode_first_stage(cc))
                        c_cat.append(cc)
                    c_cat = torch.cat(c_cat, dim=1)

                    # cond
                    cond = {"c_concat": [c_cat], "c_crossattn": [c]}

                    # uncond cond
                    uc_cross = model.get_unconditional_conditioning(num_samples, "")
                    uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
                    shape = [model.channels] + latent_hw

                    samples, _ = sampler.sample(
                        batch_size                   = num_samples,
                        S                            = t_enc,
                        x_T                          = x_T,
                        conditioning                 = cond,
                        shape                        = shape,
                        verbose                      = False,
                        unconditional_guidance_scale = cfg_scale,
                        unconditional_conditioning   = uc_full,
                        eta                          = 1.0,
                        img_callback                 = step_callback,
                        threshold                    = threshold,
                    )
                    if self.free_gpu_mem:
                        self.model.model.to("cpu")
            return self.sample_to_image(samples)

        return make_image

    def make_batch_sd(
            self,
            image,
            mask,
            masked_image,
            prompt,
            device,
            num_samples=1):
        """
        Assemble the batch dict consumed by make_image().

        Each tensor is moved to ``device`` and its leading axis (which the
        einops pattern requires to have length 1) is repeated ``num_samples``
        times.  The prompt is replicated into a list of the same length.

        Returns a dict with keys "image", "txt", "mask" and "masked_image".
        """
        batch = {
            "image": repeat(image.to(device=device), "1 ... -> n ...", n=num_samples),
            "txt": num_samples * [prompt],
            "mask": repeat(mask.to(device=device), "1 ... -> n ...", n=num_samples),
            "masked_image": repeat(masked_image.to(device=device), "1 ... -> n ...", n=num_samples),
        }
        print(f'DEBUG: image = {batch["image"]} shape={batch["image"].shape}')
        print(f'DEBUG: mask = {batch["mask"]} shape={batch["mask"].shape}')
        print(f'DEBUG: masked_image = {batch["masked_image"]} shape={batch["masked_image"].shape}')
        return batch

    def get_noise(self, width:int, height:int):
        """
        Return starting noise, delegating to Img2Img when an init latent has
        been set by get_make_image() and to Txt2Img otherwise.
        """
        # `if tensor:` is ambiguous for multi-element tensors, and the
        # attribute may never have been assigned; test for None explicitly.
        if getattr(self, 'init_latent', None) is not None:
            print('DEBUG: returning Img2Img.getnoise()')
            # super(Img2Img, self) would start the lookup *after* Img2Img in
            # the MRO; name the class directly to call the one intended.
            return Img2Img.get_noise(self, width, height)
        else:
            print('DEBUG: returning Txt2Img.getnoise()')
            return Txt2Img.get_noise(self, width, height)
|