mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
inpaint and txt2img working with ddim sampler
ldm/invoke/generator/img2img.py

@@ -77,7 +77,10 @@ class Img2Img(Generator):

     def _image_to_tensor(self, image:Image, normalize:bool=True)->Tensor:
         image = np.array(image).astype(np.float32) / 255.0
-        image = image[None].transpose(0, 3, 1, 2)
+        if len(image.shape) == 2: # 'L' image, as in a mask
+            image = image[None,None]
+        else: # 'RGB' image
+            image = image[None].transpose(0, 3, 1, 2)
         image = torch.from_numpy(image)
         if normalize:
             image = 2.0 * image - 1.0
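The change above lets `_image_to_tensor` accept single-channel 'L' masks as well as RGB images. A minimal standalone sketch of the same conversion, assuming only PIL, NumPy and PyTorch (the function name is illustrative, not the repo's API):

import numpy as np
import torch
from PIL import Image

def image_to_tensor(image: Image.Image, normalize: bool = True) -> torch.Tensor:
    arr = np.array(image).astype(np.float32) / 255.0
    if arr.ndim == 2:            # 'L' image, as in a mask -> (1, 1, H, W)
        arr = arr[None, None]
    else:                        # 'RGB' image -> (1, C, H, W)
        arr = arr[None].transpose(0, 3, 1, 2)
    tensor = torch.from_numpy(arr)
    if normalize:                # map [0, 1] to [-1, 1]
        tensor = 2.0 * tensor - 1.0
    return tensor

# a 64x64 grayscale mask becomes a (1, 1, 64, 64) tensor
print(image_to_tensor(Image.new('L', (64, 64), 255), normalize=False).shape)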
ldm/invoke/generator/omnibus.py

@@ -3,7 +3,7 @@
 import torch
 import numpy as np
 from einops import repeat
-from PIL import Image
+from PIL import Image, ImageOps
 from ldm.invoke.generator.base import downsampling
 from ldm.invoke.generator.img2img import Img2Img
 from ldm.invoke.generator.txt2img import Txt2Img
@@ -44,7 +44,7 @@ class Omnibus(Img2Img,Txt2Img):
             init_image = self._image_to_tensor(init_image)

         if isinstance(mask_image, Image.Image):
-            mask_image = self._image_to_tensor(mask_image,normalize=False)
+            mask_image = self._image_to_tensor(ImageOps.invert(mask_image).convert('L'),normalize=False)

         t_enc = steps
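For clarity, the new mask handling in isolation: the mask is inverted and collapsed to a single 'L' channel before being turned into a tensor. A small sketch using only PIL (the helper name is made up for illustration):

from PIL import Image, ImageOps

def prepare_mask(mask_image: Image.Image) -> Image.Image:
    # invert the mask and reduce it to one channel, as the hunk above does
    return ImageOps.invert(mask_image).convert('L')

# a solid white RGB mask becomes a solid black single-channel image
mask = prepare_mask(Image.new('RGB', (8, 8), (255, 255, 255)))
print(mask.mode, mask.getpixel((0, 0)))   # 'L' 0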
@@ -53,10 +53,12 @@ class Omnibus(Img2Img,Txt2Img):

         elif init_image is not None: # img2img
             scope = choose_autocast(self.precision)

             with scope(self.model.device.type):
                 self.init_latent = self.model.get_first_stage_encoding(
                     self.model.encode_first_stage(init_image)
                 ) # move to latent space

             # create a completely black mask (1s)
             mask_image = torch.ones(init_image.shape[0], 3, init_image.width, init_image.height, device=self.model.device)
             # and the masked image is just a copy of the original
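As the comments in this hunk explain, the img2img path pairs the encoded init image with an all-ones mask and a masked image that is just a copy of the original. A tensor-only sketch of that bookkeeping (plain PyTorch; the (N, 1, H, W) mask shape here is illustrative and differs from the hunk's exact expression):

import torch

def img2img_mask_pair(init_image: torch.Tensor):
    # init_image: (N, 3, H, W); keep every pixel, so the mask is all ones
    n, _, h, w = init_image.shape
    mask = torch.ones(n, 1, h, w, device=init_image.device)
    masked_image = init_image.clone()   # "the masked image is just a copy of the original"
    return mask, masked_image

mask, masked = img2img_mask_pair(torch.zeros(1, 3, 64, 64))
print(mask.shape, masked.shape)   # (1, 1, 64, 64) (1, 3, 64, 64)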
@@ -64,8 +66,9 @@ class Omnibus(Img2Img,Txt2Img):
             t_enc = int(strength * steps)

         else: # txt2img
-            mask_image = torch.zeros(init_image.shape[0], 3, init_image.width, init_image.height, device=self.model.device)
-            masked_image = mask_image
+            init_image = torch.zeros(1, 3, width, height, device=self.model.device)
+            mask_image = torch.ones(1, 1, width, height, device=self.model.device)
+            masked_image = init_image

         model = self.model
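The txt2img branch above expresses plain text-to-image as an inpainting job over a blank canvas: a zero init image, an all-ones mask, and a masked image equal to that blank canvas. A hedged sketch (plain PyTorch; the width/height dimension order simply mirrors the hunk):

import torch

def txt2img_inpaint_inputs(width: int, height: int, device: str = 'cpu'):
    init_image   = torch.zeros(1, 3, width, height, device=device)   # blank canvas
    mask_image   = torch.ones(1, 1, width, height, device=device)    # paint everywhere
    masked_image = init_image                                        # nothing is kept
    return init_image, mask_image, masked_image

init, mask, masked = txt2img_inpaint_inputs(64, 64)
print(init.shape, mask.shape, masked.shape)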
@@ -102,8 +105,8 @@ class Omnibus(Img2Img,Txt2Img):
                 uc_cross = model.get_unconditional_conditioning(num_samples, "")
                 uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
                 shape = [model.channels, height//8, width//8]

-                samples, = sampler.sample(
+                samples, _ = sampler.sample(
                     batch_size = 1,
                     S = t_enc,
                     x_T = x_T,
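The `samples, = ...` to `samples, _ = ...` fix unpacks the sampler's two-element return value (samples plus intermediates in the upstream ldm DDIM sampler) rather than expecting a single element. The `shape` line reflects the usual 8x latent downsampling; a quick worked check, assuming a 4-channel latent space as in Stable Diffusion:

def latent_shape(channels: int, height: int, width: int, factor: int = 8):
    # the sampler works in a latent space downsampled by `factor` in each spatial dim
    assert height % factor == 0 and width % factor == 0, 'dims must be multiples of the factor'
    return [channels, height // factor, width // factor]

print(latent_shape(4, 512, 512))   # [4, 64, 64]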
@@ -136,6 +139,9 @@ class Omnibus(Img2Img,Txt2Img):
             "mask": repeat(mask.to(device=device), "1 ... -> n ...", n=num_samples),
             "masked_image": repeat(masked_image.to(device=device), "1 ... -> n ...", n=num_samples),
         }
+        print(f'DEBUG: image = {batch["image"]} shape={batch["image"].shape}')
+        print(f'DEBUG: mask = {batch["mask"]} shape={batch["mask"].shape}')
+        print(f'DEBUG: masked_image = {batch["masked_image"]} shape={batch["masked_image"].shape}')
         return batch

     def get_noise(self, width:int, height:int):
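The batch construction relies on `einops.repeat` to tile single-sample tensors up to `num_samples`, and the new DEBUG prints report each entry's shape. A self-contained sketch of that tiling pattern (einops and torch only; the tensors and shapes are illustrative):

import torch
from einops import repeat

num_samples = 4
image = torch.randn(1, 3, 64, 64)
mask  = torch.rand(1, 1, 64, 64)

batch = {
    'image':        repeat(image, '1 ... -> n ...', n=num_samples),
    'mask':         repeat(mask,  '1 ... -> n ...', n=num_samples),
    'masked_image': repeat(image * (mask > 0.5), '1 ... -> n ...', n=num_samples),
}
for name, tensor in batch.items():
    print(f'DEBUG: {name} shape={tuple(tensor.shape)}')
# image (4, 3, 64, 64), mask (4, 1, 64, 64), masked_image (4, 3, 64, 64)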