mirror of https://github.com/invoke-ai/InvokeAI
inpaint and txt2img working with ddim sampler
This commit is contained in: parent 175c7bddfc, commit aaf7a4f1d3
@@ -655,6 +655,7 @@ class Generate:
         return init_image,init_mask
 
+    # lots o' repeated code here! Turn into a make_func()
     def _make_base(self):
         if not self.generators.get('base'):
             from ldm.invoke.generator import Generator
@@ -665,6 +666,7 @@ class Generate:
         if not self.generators.get('img2img'):
             from ldm.invoke.generator.img2img import Img2Img
             self.generators['img2img'] = Img2Img(self.model, self.precision)
+            self.generators['img2img'].free_gpu_mem = self.free_gpu_mem
         return self.generators['img2img']
 
     def _make_embiggen(self):
@@ -693,10 +695,13 @@ class Generate:
             self.generators['inpaint'] = Inpaint(self.model, self.precision)
         return self.generators['inpaint']
 
+    # "omnibus" supports the runwayML custom inpainting model, which does
+    # txt2img, img2img and inpainting using slight variations on the same code
     def _make_omnibus(self):
         if not self.generators.get('omnibus'):
             from ldm.invoke.generator.omnibus import Omnibus
             self.generators['omnibus'] = Omnibus(self.model, self.precision)
+            self.generators['omnibus'].free_gpu_mem = self.free_gpu_mem
         return self.generators['omnibus']
 
     def load_model(self):
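The "Turn into a make_func()" note above is a fair point: every _make_* accessor repeats the same lazy-import-and-cache dance, now with a free_gpu_mem assignment on top. A minimal sketch of such a factory, written as methods on Generate (the name _load_generator is invented here, not part of this commit):

    # Hypothetical consolidation of the repeated _make_* accessors.
    # Assumes each generator module follows the ldm.invoke.generator layout.
    def _load_generator(self, key: str, module_path: str, class_name: str):
        if not self.generators.get(key):
            import importlib
            cls = getattr(importlib.import_module(module_path), class_name)
            generator = cls(self.model, self.precision)
            generator.free_gpu_mem = self.free_gpu_mem
            self.generators[key] = generator
        return self.generators[key]

    def _make_omnibus(self):
        return self._load_generator('omnibus', 'ldm.invoke.generator.omnibus', 'Omnibus')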
@@ -77,7 +77,10 @@ class Img2Img(Generator):
 
     def _image_to_tensor(self, image:Image, normalize:bool=True)->Tensor:
         image = np.array(image).astype(np.float32) / 255.0
-        image = image[None].transpose(0, 3, 1, 2)
+        if len(image.shape) == 2:    # 'L' image, as in a mask
+            image = image[None,None]
+        else:                        # 'RGB' image
+            image = image[None].transpose(0, 3, 1, 2)
         image = torch.from_numpy(image)
         if normalize:
             image = 2.0 * image - 1.0
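The new branch exists because np.array() yields a 2-D (H, W) array for a single-channel 'L' mask, so the old unconditional transpose(0, 3, 1, 2) could not work on masks. A quick standalone illustration of the two paths:

    # Standalone illustration of the 'L' vs 'RGB' shapes handled above.
    import numpy as np

    rgb  = np.zeros((64, 64, 3), dtype=np.float32)  # PIL 'RGB' image -> (H, W, 3)
    mask = np.zeros((64, 64), dtype=np.float32)     # PIL 'L' mask    -> (H, W)

    print(rgb[None].transpose(0, 3, 1, 2).shape)    # (1, 3, 64, 64)
    print(mask[None, None].shape)                   # (1, 1, 64, 64)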
@@ -3,7 +3,7 @@
 import torch
 import numpy as np
 from einops import repeat
-from PIL import Image
+from PIL import Image, ImageOps
 from ldm.invoke.generator.base import downsampling
 from ldm.invoke.generator.img2img import Img2Img
 from ldm.invoke.generator.txt2img import Txt2Img
@@ -44,7 +44,7 @@ class Omnibus(Img2Img,Txt2Img):
             init_image = self._image_to_tensor(init_image)
 
         if isinstance(mask_image, Image.Image):
-            mask_image = self._image_to_tensor(mask_image,normalize=False)
+            mask_image = self._image_to_tensor(ImageOps.invert(mask_image).convert('L'),normalize=False)
 
         t_enc = steps
 
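The ImageOps.invert call appears to reconcile mask conventions: the runwayML inpainting model treats 1.0 as "regenerate here", while the incoming mask seems to mark keep-regions as white. A small sketch of that assumption (the convention, not the API, is the guess here):

    # Assumed convention check (not from this commit): white (255) in the
    # incoming mask means "keep"; after inversion and normalize=False scaling,
    # keep-regions become 0.0 and inpaint-regions 1.0.
    from PIL import Image, ImageOps

    incoming = Image.new('L', (64, 64), 255)         # all "keep"
    inverted = ImageOps.invert(incoming.convert('L'))
    print(inverted.getextrema())                     # (0, 0): nothing to regenerate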
@@ -53,10 +53,12 @@ class Omnibus(Img2Img,Txt2Img):
 
         elif init_image is not None: # img2img
             scope = choose_autocast(self.precision)
 
             with scope(self.model.device.type):
                 self.init_latent = self.model.get_first_stage_encoding(
                     self.model.encode_first_stage(init_image)
                 ) # move to latent space
 
+            # create a completely black mask (1s)
             mask_image = torch.ones(init_image.shape[0], 3, init_image.width, init_image.height, device=self.model.device)
+            # and the masked image is just a copy of the original
@@ -64,8 +66,9 @@ class Omnibus(Img2Img,Txt2Img):
             t_enc = int(strength * steps)
 
         else: # txt2img
-            mask_image = torch.zeros(init_image.shape[0], 3, init_image.width, init_image.height, device=self.model.device)
-            masked_image = mask_image
+            init_image = torch.zeros(1, 3, width, height, device=self.model.device)
+            mask_image = torch.ones(1, 1, width, height, device=self.model.device)
+            masked_image = init_image
 
         model = self.model
 
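The removed lines were the actual txt2img bug: init_image is None on this path, so init_image.shape[0] raised immediately, and the all-zeros mask would have told the model to regenerate nothing. The new tensors spell out "generate everything from scratch". (Incidentally, the img2img branch above calls .width/.height on init_image, which by then is a tensor; that looks like a leftover from when it was still a PIL image.) A standalone sketch of the resulting shapes, assuming a 512x512 request:

    # Sketch of the txt2img conditioning assembled above (assumed 512x512).
    import torch

    width, height = 512, 512
    init_image   = torch.zeros(1, 3, width, height)  # black source image
    mask_image   = torch.ones(1, 1, width, height)   # 1.0 everywhere: regenerate all
    masked_image = init_image                        # no source pixels survive
    print(init_image.shape, mask_image.shape, masked_image.shape)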
@@ -102,8 +105,8 @@ class Omnibus(Img2Img,Txt2Img):
             uc_cross = model.get_unconditional_conditioning(num_samples, "")
             uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
             shape = [model.channels, height//8, width//8]
 
-            samples, = sampler.sample(
+            samples, _ = sampler.sample(
                 batch_size = 1,
                 S = t_enc,
                 x_T = x_T,
@@ -136,6 +139,9 @@ class Omnibus(Img2Img,Txt2Img):
                 "mask": repeat(mask.to(device=device), "1 ... -> n ...", n=num_samples),
                 "masked_image": repeat(masked_image.to(device=device), "1 ... -> n ...", n=num_samples),
                 }
+        print(f'DEBUG: image = {batch["image"]} shape={batch["image"].shape}')
+        print(f'DEBUG: mask = {batch["mask"]} shape={batch["mask"].shape}')
+        print(f'DEBUG: masked_image = {batch["masked_image"]} shape={batch["masked_image"].shape}')
         return batch
 
     def get_noise(self, width:int, height:int):
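The repeat calls above tile single-sample tensors out to num_samples along the batch axis; for reference, a minimal einops example of that exact pattern:

    # Minimal demo of the einops pattern used in the batch dict above.
    import torch
    from einops import repeat

    mask = torch.ones(1, 1, 64, 64)
    print(repeat(mask, "1 ... -> n ...", n=4).shape)  # torch.Size([4, 1, 64, 64])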
@@ -2157,7 +2157,6 @@ class DiffusionWrapper(pl.LightningModule):
         ]
 
     def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
-        print(f'DEBUG (ddpm) c_concat = {c_concat}')
         if self.conditioning_key is None:
             out = self.diffusion_model(x, t)
         elif self.conditioning_key == 'concat':
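For context on the 'concat' branch the removed debug print was probing: upstream latent-diffusion concatenates c_concat onto the input along the channel axis, which is how the runwayML inpainting UNet receives its extra mask and masked-image channels. A sketch under that assumption:

    # Channel-concat conditioning as in upstream latent-diffusion (assumed
    # runwayML inpainting layout: 4 latent + 1 mask + 4 masked-image = 9 channels).
    import torch

    x            = torch.randn(1, 4, 64, 64)   # noisy latents
    mask         = torch.ones(1, 1, 64, 64)    # downsampled mask
    masked_image = torch.randn(1, 4, 64, 64)   # encoded masked image
    xc = torch.cat([x, mask, masked_image], dim=1)
    print(xc.shape)  # torch.Size([1, 9, 64, 64])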