inpaint and txt2img working with ddim sampler

This commit is contained in:
Lincoln Stein 2022-10-25 10:00:28 -04:00
parent 175c7bddfc
commit aaf7a4f1d3
4 changed files with 21 additions and 8 deletions

View File

@ -655,6 +655,7 @@ class Generate:
return init_image,init_mask
# lots o' repeated code here! Turn into a make_func()
def _make_base(self):
if not self.generators.get('base'):
from ldm.invoke.generator import Generator
@ -665,6 +666,7 @@ class Generate:
if not self.generators.get('img2img'):
from ldm.invoke.generator.img2img import Img2Img
self.generators['img2img'] = Img2Img(self.model, self.precision)
self.generators['img2img'].free_gpu_mem = self.free_gpu_mem
return self.generators['img2img']
def _make_embiggen(self):
@ -693,10 +695,13 @@ class Generate:
self.generators['inpaint'] = Inpaint(self.model, self.precision)
return self.generators['inpaint']
# "omnibus" supports the runwayML custom inpainting model, which does
# txt2img, img2img and inpainting using slight variations on the same code
def _make_omnibus(self):
if not self.generators.get('omnibus'):
from ldm.invoke.generator.omnibus import Omnibus
self.generators['omnibus'] = Omnibus(self.model, self.precision)
self.generators['omnibus'].free_gpu_mem = self.free_gpu_mem
return self.generators['omnibus']
def load_model(self):

View File

@ -77,7 +77,10 @@ class Img2Img(Generator):
def _image_to_tensor(self, image:Image, normalize:bool=True)->Tensor:
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
if len(image.shape) == 2: # 'L' image, as in a mask
image = image[None,None]
else: # 'RGB' image
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
if normalize:
image = 2.0 * image - 1.0

View File

@ -3,7 +3,7 @@
import torch
import numpy as np
from einops import repeat
from PIL import Image
from PIL import Image, ImageOps
from ldm.invoke.generator.base import downsampling
from ldm.invoke.generator.img2img import Img2Img
from ldm.invoke.generator.txt2img import Txt2Img
@ -44,7 +44,7 @@ class Omnibus(Img2Img,Txt2Img):
init_image = self._image_to_tensor(init_image)
if isinstance(mask_image, Image.Image):
mask_image = self._image_to_tensor(mask_image,normalize=False)
mask_image = self._image_to_tensor(ImageOps.invert(mask_image).convert('L'),normalize=False)
t_enc = steps
@ -53,10 +53,12 @@ class Omnibus(Img2Img,Txt2Img):
elif init_image is not None: # img2img
scope = choose_autocast(self.precision)
with scope(self.model.device.type):
self.init_latent = self.model.get_first_stage_encoding(
self.model.encode_first_stage(init_image)
) # move to latent space
# create a completely black mask (1s)
mask_image = torch.ones(init_image.shape[0], 3, init_image.width, init_image.height, device=self.model.device)
# and the masked image is just a copy of the original
@ -64,8 +66,9 @@ class Omnibus(Img2Img,Txt2Img):
t_enc = int(strength * steps)
else: # txt2img
mask_image = torch.zeros(init_image.shape[0], 3, init_image.width, init_image.height, device=self.model.device)
masked_image = mask_image
init_image = torch.zeros(1, 3, width, height, device=self.model.device)
mask_image = torch.ones(1, 1, width, height, device=self.model.device)
masked_image = init_image
model = self.model
@ -102,8 +105,8 @@ class Omnibus(Img2Img,Txt2Img):
uc_cross = model.get_unconditional_conditioning(num_samples, "")
uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
shape = [model.channels, height//8, width//8]
samples, = sampler.sample(
samples, _ = sampler.sample(
batch_size = 1,
S = t_enc,
x_T = x_T,
@ -136,6 +139,9 @@ class Omnibus(Img2Img,Txt2Img):
"mask": repeat(mask.to(device=device), "1 ... -> n ...", n=num_samples),
"masked_image": repeat(masked_image.to(device=device), "1 ... -> n ...", n=num_samples),
}
print(f'DEBUG: image = {batch["image"]} shape={batch["image"].shape}')
print(f'DEBUG: mask = {batch["mask"]} shape={batch["mask"].shape}')
print(f'DEBUG: masked_image = {batch["masked_image"]} shape={batch["masked_image"].shape}')
return batch
def get_noise(self, width:int, height:int):

View File

@ -2157,7 +2157,6 @@ class DiffusionWrapper(pl.LightningModule):
]
def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
print(f'DEBUG (ddpm) c_concat = {c_concat}')
if self.conditioning_key is None:
out = self.diffusion_model(x, t)
elif self.conditioning_key == 'concat':