inpaint and txt2img working with ddim sampler

This commit is contained in:
Lincoln Stein 2022-10-25 10:00:28 -04:00
parent 175c7bddfc
commit aaf7a4f1d3
4 changed files with 21 additions and 8 deletions

View File

@ -655,6 +655,7 @@ class Generate:
return init_image,init_mask return init_image,init_mask
# lots o' repeated code here! Turn into a make_func()
def _make_base(self): def _make_base(self):
if not self.generators.get('base'): if not self.generators.get('base'):
from ldm.invoke.generator import Generator from ldm.invoke.generator import Generator
@ -665,6 +666,7 @@ class Generate:
if not self.generators.get('img2img'): if not self.generators.get('img2img'):
from ldm.invoke.generator.img2img import Img2Img from ldm.invoke.generator.img2img import Img2Img
self.generators['img2img'] = Img2Img(self.model, self.precision) self.generators['img2img'] = Img2Img(self.model, self.precision)
self.generators['img2img'].free_gpu_mem = self.free_gpu_mem
return self.generators['img2img'] return self.generators['img2img']
def _make_embiggen(self): def _make_embiggen(self):
@ -693,10 +695,13 @@ class Generate:
self.generators['inpaint'] = Inpaint(self.model, self.precision) self.generators['inpaint'] = Inpaint(self.model, self.precision)
return self.generators['inpaint'] return self.generators['inpaint']
# "omnibus" supports the runwayML custom inpainting model, which does
# txt2img, img2img and inpainting using slight variations on the same code
def _make_omnibus(self): def _make_omnibus(self):
if not self.generators.get('omnibus'): if not self.generators.get('omnibus'):
from ldm.invoke.generator.omnibus import Omnibus from ldm.invoke.generator.omnibus import Omnibus
self.generators['omnibus'] = Omnibus(self.model, self.precision) self.generators['omnibus'] = Omnibus(self.model, self.precision)
self.generators['omnibus'].free_gpu_mem = self.free_gpu_mem
return self.generators['omnibus'] return self.generators['omnibus']
def load_model(self): def load_model(self):

View File

@ -77,6 +77,9 @@ class Img2Img(Generator):
def _image_to_tensor(self, image:Image, normalize:bool=True)->Tensor: def _image_to_tensor(self, image:Image, normalize:bool=True)->Tensor:
image = np.array(image).astype(np.float32) / 255.0 image = np.array(image).astype(np.float32) / 255.0
if len(image.shape) == 2: # 'L' image, as in a mask
image = image[None,None]
else: # 'RGB' image
image = image[None].transpose(0, 3, 1, 2) image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image) image = torch.from_numpy(image)
if normalize: if normalize:

View File

@ -3,7 +3,7 @@
import torch import torch
import numpy as np import numpy as np
from einops import repeat from einops import repeat
from PIL import Image from PIL import Image, ImageOps
from ldm.invoke.generator.base import downsampling from ldm.invoke.generator.base import downsampling
from ldm.invoke.generator.img2img import Img2Img from ldm.invoke.generator.img2img import Img2Img
from ldm.invoke.generator.txt2img import Txt2Img from ldm.invoke.generator.txt2img import Txt2Img
@ -44,7 +44,7 @@ class Omnibus(Img2Img,Txt2Img):
init_image = self._image_to_tensor(init_image) init_image = self._image_to_tensor(init_image)
if isinstance(mask_image, Image.Image): if isinstance(mask_image, Image.Image):
mask_image = self._image_to_tensor(mask_image,normalize=False) mask_image = self._image_to_tensor(ImageOps.invert(mask_image).convert('L'),normalize=False)
t_enc = steps t_enc = steps
@ -53,10 +53,12 @@ class Omnibus(Img2Img,Txt2Img):
elif init_image is not None: # img2img elif init_image is not None: # img2img
scope = choose_autocast(self.precision) scope = choose_autocast(self.precision)
with scope(self.model.device.type): with scope(self.model.device.type):
self.init_latent = self.model.get_first_stage_encoding( self.init_latent = self.model.get_first_stage_encoding(
self.model.encode_first_stage(init_image) self.model.encode_first_stage(init_image)
) # move to latent space ) # move to latent space
# create a completely black mask (1s) # create a completely black mask (1s)
mask_image = torch.ones(init_image.shape[0], 3, init_image.width, init_image.height, device=self.model.device) mask_image = torch.ones(init_image.shape[0], 3, init_image.width, init_image.height, device=self.model.device)
# and the masked image is just a copy of the original # and the masked image is just a copy of the original
@ -64,8 +66,9 @@ class Omnibus(Img2Img,Txt2Img):
t_enc = int(strength * steps) t_enc = int(strength * steps)
else: # txt2img else: # txt2img
mask_image = torch.zeros(init_image.shape[0], 3, init_image.width, init_image.height, device=self.model.device) init_image = torch.zeros(1, 3, width, height, device=self.model.device)
masked_image = mask_image mask_image = torch.ones(1, 1, width, height, device=self.model.device)
masked_image = init_image
model = self.model model = self.model
@ -103,7 +106,7 @@ class Omnibus(Img2Img,Txt2Img):
uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]} uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
shape = [model.channels, height//8, width//8] shape = [model.channels, height//8, width//8]
samples, = sampler.sample( samples, _ = sampler.sample(
batch_size = 1, batch_size = 1,
S = t_enc, S = t_enc,
x_T = x_T, x_T = x_T,
@ -136,6 +139,9 @@ class Omnibus(Img2Img,Txt2Img):
"mask": repeat(mask.to(device=device), "1 ... -> n ...", n=num_samples), "mask": repeat(mask.to(device=device), "1 ... -> n ...", n=num_samples),
"masked_image": repeat(masked_image.to(device=device), "1 ... -> n ...", n=num_samples), "masked_image": repeat(masked_image.to(device=device), "1 ... -> n ...", n=num_samples),
} }
print(f'DEBUG: image = {batch["image"]} shape={batch["image"].shape}')
print(f'DEBUG: mask = {batch["mask"]} shape={batch["mask"].shape}')
print(f'DEBUG: masked_image = {batch["masked_image"]} shape={batch["masked_image"].shape}')
return batch return batch
def get_noise(self, width:int, height:int): def get_noise(self, width:int, height:int):

View File

@ -2157,7 +2157,6 @@ class DiffusionWrapper(pl.LightningModule):
] ]
def forward(self, x, t, c_concat: list = None, c_crossattn: list = None): def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
print(f'DEBUG (ddpm) c_concat = {c_concat}')
if self.conditioning_key is None: if self.conditioning_key is None:
out = self.diffusion_model(x, t) out = self.diffusion_model(x, t)
elif self.conditioning_key == 'concat': elif self.conditioning_key == 'concat':