inpaint and txt2img working with ddim sampler

This commit is contained in:
Lincoln Stein
2022-10-25 10:00:28 -04:00
parent 175c7bddfc
commit aaf7a4f1d3
4 changed files with 21 additions and 8 deletions

View File

@ -77,7 +77,10 @@ class Img2Img(Generator):
def _image_to_tensor(self, image:Image, normalize:bool=True)->Tensor:
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
if len(image.shape) == 2: # 'L' image, as in a mask
image = image[None,None]
else: # 'RGB' image
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
if normalize:
image = 2.0 * image - 1.0

View File

@ -3,7 +3,7 @@
import torch
import numpy as np
from einops import repeat
from PIL import Image
from PIL import Image, ImageOps
from ldm.invoke.generator.base import downsampling
from ldm.invoke.generator.img2img import Img2Img
from ldm.invoke.generator.txt2img import Txt2Img
@ -44,7 +44,7 @@ class Omnibus(Img2Img,Txt2Img):
init_image = self._image_to_tensor(init_image)
if isinstance(mask_image, Image.Image):
mask_image = self._image_to_tensor(mask_image,normalize=False)
mask_image = self._image_to_tensor(ImageOps.invert(mask_image).convert('L'),normalize=False)
t_enc = steps
@ -53,10 +53,12 @@ class Omnibus(Img2Img,Txt2Img):
elif init_image is not None: # img2img
scope = choose_autocast(self.precision)
with scope(self.model.device.type):
self.init_latent = self.model.get_first_stage_encoding(
self.model.encode_first_stage(init_image)
) # move to latent space
# create a completely black mask (1s)
mask_image = torch.ones(init_image.shape[0], 3, init_image.width, init_image.height, device=self.model.device)
# and the masked image is just a copy of the original
@ -64,8 +66,9 @@ class Omnibus(Img2Img,Txt2Img):
t_enc = int(strength * steps)
else: # txt2img
mask_image = torch.zeros(init_image.shape[0], 3, init_image.width, init_image.height, device=self.model.device)
masked_image = mask_image
init_image = torch.zeros(1, 3, width, height, device=self.model.device)
mask_image = torch.ones(1, 1, width, height, device=self.model.device)
masked_image = init_image
model = self.model
@ -102,8 +105,8 @@ class Omnibus(Img2Img,Txt2Img):
uc_cross = model.get_unconditional_conditioning(num_samples, "")
uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
shape = [model.channels, height//8, width//8]
samples, = sampler.sample(
samples, _ = sampler.sample(
batch_size = 1,
S = t_enc,
x_T = x_T,
@ -136,6 +139,9 @@ class Omnibus(Img2Img,Txt2Img):
"mask": repeat(mask.to(device=device), "1 ... -> n ...", n=num_samples),
"masked_image": repeat(masked_image.to(device=device), "1 ... -> n ...", n=num_samples),
}
print(f'DEBUG: image = {batch["image"]} shape={batch["image"].shape}')
print(f'DEBUG: mask = {batch["mask"]} shape={batch["mask"].shape}')
print(f'DEBUG: masked_image = {batch["masked_image"]} shape={batch["masked_image"].shape}')
return batch
def get_noise(self, width:int, height:int):