mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'development' into development
This commit is contained in:
@ -6,6 +6,7 @@ import torch
|
||||
import numpy as np
|
||||
import random
|
||||
import os
|
||||
import traceback
|
||||
from tqdm import tqdm, trange
|
||||
from PIL import Image, ImageFilter
|
||||
from einops import rearrange, repeat
|
||||
@ -28,7 +29,8 @@ class Generator():
|
||||
self.threshold = 0
|
||||
self.variation_amount = 0
|
||||
self.with_variations = []
|
||||
self.use_mps_noise = False
|
||||
self.use_mps_noise = False
|
||||
self.free_gpu_mem = None
|
||||
|
||||
# this is going to be overridden in img2img.py, txt2img.py and inpaint.py
|
||||
def get_make_image(self,prompt,**kwargs):
|
||||
@ -43,14 +45,15 @@ class Generator():
|
||||
self.variation_amount = variation_amount
|
||||
self.with_variations = with_variations
|
||||
|
||||
def generate(self,prompt,init_image,width,height,iterations=1,seed=None,
|
||||
def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None,
|
||||
image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
|
||||
safety_checker:dict=None,
|
||||
**kwargs):
|
||||
scope = choose_autocast(self.precision)
|
||||
self.safety_checker = safety_checker
|
||||
make_image = self.get_make_image(
|
||||
make_image = self.get_make_image(
|
||||
prompt,
|
||||
sampler = sampler,
|
||||
init_image = init_image,
|
||||
width = width,
|
||||
height = height,
|
||||
@ -59,12 +62,14 @@ class Generator():
|
||||
perlin = perlin,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
results = []
|
||||
seed = seed if seed is not None else self.new_seed()
|
||||
first_seed = seed
|
||||
seed, initial_noise = self.generate_initial_noise(seed, width, height)
|
||||
with scope(self.model.device.type), self.model.ema_scope():
|
||||
|
||||
# There used to be an additional self.model.ema_scope() here, but it breaks
|
||||
# the inpaint-1.5 model. Not sure what it did.... ?
|
||||
with scope(self.model.device.type):
|
||||
for n in trange(iterations, desc='Generating'):
|
||||
x_T = None
|
||||
if self.variation_amount > 0:
|
||||
@ -79,7 +84,8 @@ class Generator():
|
||||
try:
|
||||
x_T = self.get_noise(width,height)
|
||||
except:
|
||||
pass
|
||||
print('** An error occurred while getting initial noise **')
|
||||
print(traceback.format_exc())
|
||||
|
||||
image = make_image(x_T)
|
||||
|
||||
@ -95,10 +101,10 @@ class Generator():
|
||||
|
||||
return results
|
||||
|
||||
def sample_to_image(self,samples):
|
||||
def sample_to_image(self,samples)->Image.Image:
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and the initial image
|
||||
Return value depends on the seed at the time you call it
|
||||
Given samples returned from a sampler, converts
|
||||
it into a PIL Image
|
||||
"""
|
||||
x_samples = self.model.decode_first_stage(samples)
|
||||
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
|
@ -21,6 +21,7 @@ class Embiggen(Generator):
|
||||
def generate(self,prompt,iterations=1,seed=None,
|
||||
image_callback=None, step_callback=None,
|
||||
**kwargs):
|
||||
|
||||
scope = choose_autocast(self.precision)
|
||||
make_image = self.get_make_image(
|
||||
prompt,
|
||||
@ -63,6 +64,8 @@ class Embiggen(Generator):
|
||||
Returns a function returning an image derived from the prompt and multi-stage twice-baked potato layering over the img2img on the initial image
|
||||
Return value depends on the seed at the time you call it
|
||||
"""
|
||||
assert not sampler.uses_inpainting_model(), "--embiggen is not supported by inpainting models"
|
||||
|
||||
# Construct embiggen arg array, and sanity check arguments
|
||||
if embiggen == None: # embiggen can also be called with just embiggen_tiles
|
||||
embiggen = [1.0] # If not specified, assume no scaling
|
||||
|
@ -10,11 +10,12 @@ from PIL import Image
|
||||
from ldm.invoke.devices import choose_autocast
|
||||
from ldm.invoke.generator.base import Generator
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
|
||||
class Img2Img(Generator):
|
||||
def __init__(self, model, precision):
|
||||
super().__init__(model, precision)
|
||||
self.init_latent = None # by get_noise()
|
||||
self.init_latent = None # by get_noise()
|
||||
|
||||
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
|
||||
conditioning,init_image,strength,step_callback=None,threshold=0.0,perlin=0.0,**kwargs):
|
||||
@ -29,7 +30,7 @@ class Img2Img(Generator):
|
||||
)
|
||||
|
||||
if isinstance(init_image, PIL.Image.Image):
|
||||
init_image = self._image_to_tensor(init_image)
|
||||
init_image = self._image_to_tensor(init_image.convert('RGB'))
|
||||
|
||||
scope = choose_autocast(self.precision)
|
||||
with scope(self.model.device.type):
|
||||
@ -38,7 +39,7 @@ class Img2Img(Generator):
|
||||
) # move to latent space
|
||||
|
||||
t_enc = int(strength * steps)
|
||||
uc, c = conditioning
|
||||
uc, c, extra_conditioning_info = conditioning
|
||||
|
||||
def make_image(x_T):
|
||||
# encode (scaled latent)
|
||||
@ -55,7 +56,9 @@ class Img2Img(Generator):
|
||||
img_callback = step_callback,
|
||||
unconditional_guidance_scale=cfg_scale,
|
||||
unconditional_conditioning=uc,
|
||||
init_latent = self.init_latent, # changes how noising is performed in ksampler
|
||||
init_latent = self.init_latent, # changes how noising is performed in ksampler
|
||||
extra_conditioning_info = extra_conditioning_info,
|
||||
all_timesteps_count = steps
|
||||
)
|
||||
|
||||
return self.sample_to_image(samples)
|
||||
@ -77,7 +80,10 @@ class Img2Img(Generator):
|
||||
|
||||
def _image_to_tensor(self, image:Image, normalize:bool=True)->Tensor:
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
if len(image.shape) == 2: # 'L' image, as in a mask
|
||||
image = image[None,None]
|
||||
else: # 'RGB' image
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image)
|
||||
if normalize:
|
||||
image = 2.0 * image - 1.0
|
||||
|
@ -2,12 +2,13 @@
|
||||
ldm.invoke.generator.inpaint descends from ldm.invoke.generator
|
||||
'''
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
import numpy as np
|
||||
import cv2 as cv
|
||||
import PIL
|
||||
from PIL import Image, ImageFilter
|
||||
from PIL import Image, ImageFilter, ImageOps
|
||||
from skimage.exposure.histogram_matching import match_histograms
|
||||
from einops import rearrange, repeat
|
||||
from ldm.invoke.devices import choose_autocast
|
||||
@ -24,11 +25,128 @@ class Inpaint(Img2Img):
|
||||
self.mask_blur_radius = 0
|
||||
super().__init__(model, precision)
|
||||
|
||||
# Outpaint support code
|
||||
def get_tile_images(self, image: np.ndarray, width=8, height=8):
|
||||
_nrows, _ncols, depth = image.shape
|
||||
_strides = image.strides
|
||||
|
||||
nrows, _m = divmod(_nrows, height)
|
||||
ncols, _n = divmod(_ncols, width)
|
||||
if _m != 0 or _n != 0:
|
||||
return None
|
||||
|
||||
return np.lib.stride_tricks.as_strided(
|
||||
np.ravel(image),
|
||||
shape=(nrows, ncols, height, width, depth),
|
||||
strides=(height * _strides[0], width * _strides[1], *_strides),
|
||||
writeable=False
|
||||
)
|
||||
|
||||
def tile_fill_missing(self, im: Image.Image, tile_size: int = 16, seed: int = None) -> Image:
|
||||
a = np.asarray(im, dtype=np.uint8)
|
||||
|
||||
tile_size = (tile_size, tile_size)
|
||||
|
||||
# Get the image as tiles of a specified size
|
||||
tiles = self.get_tile_images(a,*tile_size).copy()
|
||||
|
||||
# Get the mask as tiles
|
||||
tiles_mask = tiles[:,:,:,:,3]
|
||||
|
||||
# Find any mask tiles with any fully transparent pixels (we will be replacing these later)
|
||||
tmask_shape = tiles_mask.shape
|
||||
tiles_mask = tiles_mask.reshape(math.prod(tiles_mask.shape))
|
||||
n,ny = (math.prod(tmask_shape[0:2])), math.prod(tmask_shape[2:])
|
||||
tiles_mask = (tiles_mask > 0)
|
||||
tiles_mask = tiles_mask.reshape((n,ny)).all(axis = 1)
|
||||
|
||||
# Get RGB tiles in single array and filter by the mask
|
||||
tshape = tiles.shape
|
||||
tiles_all = tiles.reshape((math.prod(tiles.shape[0:2]), * tiles.shape[2:]))
|
||||
filtered_tiles = tiles_all[tiles_mask]
|
||||
|
||||
if len(filtered_tiles) == 0:
|
||||
return im
|
||||
|
||||
# Find all invalid tiles and replace with a random valid tile
|
||||
replace_count = (tiles_mask == False).sum()
|
||||
rng = np.random.default_rng(seed = seed)
|
||||
tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[rng.choice(filtered_tiles.shape[0], replace_count),:,:,:]
|
||||
|
||||
# Convert back to an image
|
||||
tiles_all = tiles_all.reshape(tshape)
|
||||
tiles_all = tiles_all.swapaxes(1,2)
|
||||
st = tiles_all.reshape((math.prod(tiles_all.shape[0:2]), math.prod(tiles_all.shape[2:4]), tiles_all.shape[4]))
|
||||
si = Image.fromarray(st, mode='RGBA')
|
||||
|
||||
return si
|
||||
|
||||
|
||||
def mask_edge(self, mask: Image, edge_size: int, edge_blur: int) -> Image:
|
||||
npimg = np.asarray(mask, dtype=np.uint8)
|
||||
|
||||
# Detect any partially transparent regions
|
||||
npgradient = np.uint8(255 * (1.0 - np.floor(np.abs(0.5 - np.float32(npimg) / 255.0) * 2.0)))
|
||||
|
||||
# Detect hard edges
|
||||
npedge = cv.Canny(npimg, threshold1=100, threshold2=200)
|
||||
|
||||
# Combine
|
||||
npmask = npgradient + npedge
|
||||
|
||||
# Expand
|
||||
npmask = cv.dilate(npmask, np.ones((3,3), np.uint8), iterations = int(edge_size / 2))
|
||||
|
||||
new_mask = Image.fromarray(npmask)
|
||||
|
||||
if edge_blur > 0:
|
||||
new_mask = new_mask.filter(ImageFilter.BoxBlur(edge_blur))
|
||||
|
||||
return ImageOps.invert(new_mask)
|
||||
|
||||
|
||||
def seam_paint(self,
|
||||
im: Image.Image,
|
||||
seam_size: int,
|
||||
seam_blur: int,
|
||||
prompt,sampler,steps,cfg_scale,ddim_eta,
|
||||
conditioning,strength,
|
||||
noise
|
||||
) -> Image.Image:
|
||||
hard_mask = self.pil_image.split()[-1].copy()
|
||||
mask = self.mask_edge(hard_mask, seam_size, seam_blur)
|
||||
|
||||
make_image = self.get_make_image(
|
||||
prompt,
|
||||
sampler,
|
||||
steps,
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
conditioning,
|
||||
init_image = im.copy().convert('RGBA'),
|
||||
mask_image = mask.convert('RGB'), # Code currently requires an RGB mask
|
||||
strength = strength,
|
||||
mask_blur_radius = 0,
|
||||
seam_size = 0
|
||||
)
|
||||
|
||||
result = make_image(noise)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
|
||||
conditioning,init_image,mask_image,strength,
|
||||
mask_blur_radius: int = 8,
|
||||
step_callback=None,inpaint_replace=False, **kwargs):
|
||||
# Seam settings - when 0, doesn't fill seam
|
||||
seam_size: int = 0,
|
||||
seam_blur: int = 0,
|
||||
seam_strength: float = 0.7,
|
||||
seam_steps: int = 10,
|
||||
tile_size: int = 32,
|
||||
step_callback=None,
|
||||
inpaint_replace=False, **kwargs):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and
|
||||
the initial image + mask. Return value depends on the seed at
|
||||
@ -37,7 +155,17 @@ class Inpaint(Img2Img):
|
||||
|
||||
if isinstance(init_image, PIL.Image.Image):
|
||||
self.pil_image = init_image
|
||||
init_image = self._image_to_tensor(init_image)
|
||||
|
||||
# Fill missing areas of original image
|
||||
init_filled = self.tile_fill_missing(
|
||||
self.pil_image.copy(),
|
||||
seed = self.seed,
|
||||
tile_size = tile_size
|
||||
)
|
||||
init_filled.paste(init_image, (0,0), init_image.split()[-1])
|
||||
|
||||
# Create init tensor
|
||||
init_image = self._image_to_tensor(init_filled.convert('RGB'))
|
||||
|
||||
if isinstance(mask_image, PIL.Image.Image):
|
||||
self.pil_mask = mask_image
|
||||
@ -73,7 +201,8 @@ class Inpaint(Img2Img):
|
||||
) # move to latent space
|
||||
|
||||
t_enc = int(strength * steps)
|
||||
uc, c = conditioning
|
||||
# todo: support cross-attention control
|
||||
uc, c, _ = conditioning
|
||||
|
||||
print(f">> target t_enc is {t_enc} steps")
|
||||
|
||||
@ -105,38 +234,56 @@ class Inpaint(Img2Img):
|
||||
mask = mask_image,
|
||||
init_latent = self.init_latent
|
||||
)
|
||||
return self.sample_to_image(samples)
|
||||
|
||||
result = self.sample_to_image(samples)
|
||||
|
||||
# Seam paint if this is our first pass (seam_size set to 0 during seam painting)
|
||||
if seam_size > 0:
|
||||
result = self.seam_paint(
|
||||
result,
|
||||
seam_size,
|
||||
seam_blur,
|
||||
prompt,
|
||||
sampler,
|
||||
seam_steps,
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
conditioning,
|
||||
seam_strength,
|
||||
x_T)
|
||||
|
||||
return result
|
||||
|
||||
return make_image
|
||||
|
||||
def sample_to_image(self, samples)->Image.Image:
|
||||
gen_result = super().sample_to_image(samples).convert('RGB')
|
||||
|
||||
if self.pil_image is None or self.pil_mask is None:
|
||||
return gen_result
|
||||
|
||||
pil_mask = self.pil_mask
|
||||
pil_image = self.pil_image
|
||||
mask_blur_radius = self.mask_blur_radius
|
||||
|
||||
def color_correct(self, image: Image.Image, base_image: Image.Image, mask: Image.Image, mask_blur_radius: int) -> Image.Image:
|
||||
# Get the original alpha channel of the mask if there is one.
|
||||
# Otherwise it is some other black/white image format ('1', 'L' or 'RGB')
|
||||
pil_init_mask = pil_mask.getchannel('A') if pil_mask.mode == 'RGBA' else pil_mask.convert('L')
|
||||
pil_init_image = pil_image.convert('RGBA') # Add an alpha channel if one doesn't exist
|
||||
pil_init_mask = mask.getchannel('A') if mask.mode == 'RGBA' else mask.convert('L')
|
||||
pil_init_image = base_image.convert('RGBA') # Add an alpha channel if one doesn't exist
|
||||
|
||||
# Build an image with only visible pixels from source to use as reference for color-matching.
|
||||
# Note that this doesn't use the mask, which would exclude some source image pixels from the
|
||||
# histogram and cause slight color changes.
|
||||
init_rgb_pixels = np.asarray(pil_image.convert('RGB'), dtype=np.uint8).reshape(pil_image.width * pil_image.height, 3)
|
||||
init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8).reshape(pil_init_mask.width * pil_init_mask.height)
|
||||
init_rgb_pixels = init_rgb_pixels[init_a_pixels > 0]
|
||||
init_rgb_pixels = init_rgb_pixels.reshape(1, init_rgb_pixels.shape[0], init_rgb_pixels.shape[1]) # Filter to just pixels that have any alpha, this is now our histogram
|
||||
init_rgb_pixels = np.asarray(base_image.convert('RGB'), dtype=np.uint8)
|
||||
init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8)
|
||||
init_mask_pixels = np.asarray(pil_init_mask, dtype=np.uint8)
|
||||
|
||||
# Get numpy version
|
||||
np_gen_result = np.asarray(gen_result, dtype=np.uint8)
|
||||
# Get numpy version of result
|
||||
np_image = np.asarray(image, dtype=np.uint8)
|
||||
|
||||
# Mask and calculate mean and standard deviation
|
||||
mask_pixels = init_a_pixels * init_mask_pixels > 0
|
||||
np_init_rgb_pixels_masked = init_rgb_pixels[mask_pixels, :]
|
||||
np_image_masked = np_image[mask_pixels, :]
|
||||
|
||||
init_means = np_init_rgb_pixels_masked.mean(axis=0)
|
||||
init_std = np_init_rgb_pixels_masked.std(axis=0)
|
||||
gen_means = np_image_masked.mean(axis=0)
|
||||
gen_std = np_image_masked.std(axis=0)
|
||||
|
||||
# Color correct
|
||||
np_matched_result = match_histograms(np_gen_result, init_rgb_pixels, channel_axis=-1)
|
||||
np_matched_result = np_image.copy()
|
||||
np_matched_result[:,:,:] = (((np_matched_result[:,:,:].astype(np.float32) - gen_means[None,None,:]) / gen_std[None,None,:]) * init_std[None,None,:] + init_means[None,None,:]).clip(0, 255).astype(np.uint8)
|
||||
matched_result = Image.fromarray(np_matched_result, mode='RGB')
|
||||
|
||||
# Blur the mask out (into init image) by specified amount
|
||||
@ -149,6 +296,16 @@ class Inpaint(Img2Img):
|
||||
blurred_init_mask = pil_init_mask
|
||||
|
||||
# Paste original on color-corrected generation (using blurred mask)
|
||||
matched_result.paste(pil_image, (0,0), mask = blurred_init_mask)
|
||||
matched_result.paste(base_image, (0,0), mask = blurred_init_mask)
|
||||
return matched_result
|
||||
|
||||
|
||||
def sample_to_image(self, samples)->Image.Image:
|
||||
gen_result = super().sample_to_image(samples).convert('RGB')
|
||||
|
||||
if self.pil_image is None or self.pil_mask is None:
|
||||
return gen_result
|
||||
|
||||
corrected_result = self.color_correct(gen_result, self.pil_image, self.pil_mask, self.mask_blur_radius)
|
||||
|
||||
return corrected_result
|
||||
|
153
ldm/invoke/generator/omnibus.py
Normal file
153
ldm/invoke/generator/omnibus.py
Normal file
@ -0,0 +1,153 @@
|
||||
"""omnibus module to be used with the runwayml 9-channel custom inpainting model"""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from einops import repeat
|
||||
from PIL import Image, ImageOps
|
||||
from ldm.invoke.devices import choose_autocast
|
||||
from ldm.invoke.generator.base import downsampling
|
||||
from ldm.invoke.generator.img2img import Img2Img
|
||||
from ldm.invoke.generator.txt2img import Txt2Img
|
||||
|
||||
class Omnibus(Img2Img,Txt2Img):
|
||||
def __init__(self, model, precision):
|
||||
super().__init__(model, precision)
|
||||
|
||||
def get_make_image(
|
||||
self,
|
||||
prompt,
|
||||
sampler,
|
||||
steps,
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
conditioning,
|
||||
width,
|
||||
height,
|
||||
init_image = None,
|
||||
mask_image = None,
|
||||
strength = None,
|
||||
step_callback=None,
|
||||
threshold=0.0,
|
||||
perlin=0.0,
|
||||
**kwargs):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and the initial image
|
||||
Return value depends on the seed at the time you call it.
|
||||
"""
|
||||
self.perlin = perlin
|
||||
num_samples = 1
|
||||
|
||||
sampler.make_schedule(
|
||||
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
|
||||
)
|
||||
|
||||
if isinstance(init_image, Image.Image):
|
||||
if init_image.mode != 'RGB':
|
||||
init_image = init_image.convert('RGB')
|
||||
init_image = self._image_to_tensor(init_image)
|
||||
|
||||
if isinstance(mask_image, Image.Image):
|
||||
mask_image = self._image_to_tensor(ImageOps.invert(mask_image).convert('L'),normalize=False)
|
||||
|
||||
t_enc = steps
|
||||
|
||||
if init_image is not None and mask_image is not None: # inpainting
|
||||
masked_image = init_image * (1 - mask_image) # masked image is the image masked by mask - masked regions zero
|
||||
|
||||
elif init_image is not None: # img2img
|
||||
scope = choose_autocast(self.precision)
|
||||
|
||||
with scope(self.model.device.type):
|
||||
self.init_latent = self.model.get_first_stage_encoding(
|
||||
self.model.encode_first_stage(init_image)
|
||||
) # move to latent space
|
||||
|
||||
# create a completely black mask (1s)
|
||||
mask_image = torch.ones(1, 1, init_image.shape[2], init_image.shape[3], device=self.model.device)
|
||||
# and the masked image is just a copy of the original
|
||||
masked_image = init_image
|
||||
|
||||
else: # txt2img
|
||||
init_image = torch.zeros(1, 3, height, width, device=self.model.device)
|
||||
mask_image = torch.ones(1, 1, height, width, device=self.model.device)
|
||||
masked_image = init_image
|
||||
|
||||
self.init_latent = init_image
|
||||
height = init_image.shape[2]
|
||||
width = init_image.shape[3]
|
||||
model = self.model
|
||||
|
||||
def make_image(x_T):
|
||||
with torch.no_grad():
|
||||
scope = choose_autocast(self.precision)
|
||||
with scope(self.model.device.type):
|
||||
|
||||
batch = self.make_batch_sd(
|
||||
init_image,
|
||||
mask_image,
|
||||
masked_image,
|
||||
prompt=prompt,
|
||||
device=model.device,
|
||||
num_samples=num_samples,
|
||||
)
|
||||
|
||||
c = model.cond_stage_model.encode(batch["txt"])
|
||||
c_cat = list()
|
||||
for ck in model.concat_keys:
|
||||
cc = batch[ck].float()
|
||||
if ck != model.masked_image_key:
|
||||
bchw = [num_samples, 4, height//8, width//8]
|
||||
cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
|
||||
else:
|
||||
cc = model.get_first_stage_encoding(model.encode_first_stage(cc))
|
||||
c_cat.append(cc)
|
||||
c_cat = torch.cat(c_cat, dim=1)
|
||||
|
||||
# cond
|
||||
cond={"c_concat": [c_cat], "c_crossattn": [c]}
|
||||
|
||||
# uncond cond
|
||||
uc_cross = model.get_unconditional_conditioning(num_samples, "")
|
||||
uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
|
||||
shape = [model.channels, height//8, width//8]
|
||||
|
||||
samples, _ = sampler.sample(
|
||||
batch_size = 1,
|
||||
S = steps,
|
||||
x_T = x_T,
|
||||
conditioning = cond,
|
||||
shape = shape,
|
||||
verbose = False,
|
||||
unconditional_guidance_scale = cfg_scale,
|
||||
unconditional_conditioning = uc_full,
|
||||
eta = 1.0,
|
||||
img_callback = step_callback,
|
||||
threshold = threshold,
|
||||
)
|
||||
if self.free_gpu_mem:
|
||||
self.model.model.to("cpu")
|
||||
return self.sample_to_image(samples)
|
||||
|
||||
return make_image
|
||||
|
||||
def make_batch_sd(
|
||||
self,
|
||||
image,
|
||||
mask,
|
||||
masked_image,
|
||||
prompt,
|
||||
device,
|
||||
num_samples=1):
|
||||
batch = {
|
||||
"image": repeat(image.to(device=device), "1 ... -> n ...", n=num_samples),
|
||||
"txt": num_samples * [prompt],
|
||||
"mask": repeat(mask.to(device=device), "1 ... -> n ...", n=num_samples),
|
||||
"masked_image": repeat(masked_image.to(device=device), "1 ... -> n ...", n=num_samples),
|
||||
}
|
||||
return batch
|
||||
|
||||
def get_noise(self, width:int, height:int):
|
||||
if self.init_latent is not None:
|
||||
height = self.init_latent.shape[2]
|
||||
width = self.init_latent.shape[3]
|
||||
return Txt2Img.get_noise(self,width,height)
|
@ -5,6 +5,8 @@ ldm.invoke.generator.txt2img inherits from ldm.invoke.generator
|
||||
import torch
|
||||
import numpy as np
|
||||
from ldm.invoke.generator.base import Generator
|
||||
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
|
||||
|
||||
class Txt2Img(Generator):
|
||||
def __init__(self, model, precision):
|
||||
@ -19,7 +21,7 @@ class Txt2Img(Generator):
|
||||
kwargs are 'width' and 'height'
|
||||
"""
|
||||
self.perlin = perlin
|
||||
uc, c = conditioning
|
||||
uc, c, extra_conditioning_info = conditioning
|
||||
|
||||
@torch.no_grad()
|
||||
def make_image(x_T):
|
||||
@ -43,6 +45,7 @@ class Txt2Img(Generator):
|
||||
verbose = False,
|
||||
unconditional_guidance_scale = cfg_scale,
|
||||
unconditional_conditioning = uc,
|
||||
extra_conditioning_info = extra_conditioning_info,
|
||||
eta = ddim_eta,
|
||||
img_callback = step_callback,
|
||||
threshold = threshold,
|
||||
|
@ -5,9 +5,11 @@ ldm.invoke.generator.txt2img inherits from ldm.invoke.generator
|
||||
import torch
|
||||
import numpy as np
|
||||
import math
|
||||
from ldm.invoke.generator.base import Generator
|
||||
from ldm.invoke.generator.base import Generator
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
|
||||
from ldm.invoke.generator.omnibus import Omnibus
|
||||
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
from PIL import Image
|
||||
|
||||
class Txt2Img2Img(Generator):
|
||||
def __init__(self, model, precision):
|
||||
@ -22,7 +24,7 @@ class Txt2Img2Img(Generator):
|
||||
Return value depends on the seed at the time you call it
|
||||
kwargs are 'width' and 'height'
|
||||
"""
|
||||
uc, c = conditioning
|
||||
uc, c, extra_conditioning_info = conditioning
|
||||
|
||||
@torch.no_grad()
|
||||
def make_image(x_T):
|
||||
@ -59,7 +61,8 @@ class Txt2Img2Img(Generator):
|
||||
unconditional_guidance_scale = cfg_scale,
|
||||
unconditional_conditioning = uc,
|
||||
eta = ddim_eta,
|
||||
img_callback = step_callback
|
||||
img_callback = step_callback,
|
||||
extra_conditioning_info = extra_conditioning_info
|
||||
)
|
||||
|
||||
print(
|
||||
@ -93,6 +96,8 @@ class Txt2Img2Img(Generator):
|
||||
img_callback = step_callback,
|
||||
unconditional_guidance_scale=cfg_scale,
|
||||
unconditional_conditioning=uc,
|
||||
extra_conditioning_info=extra_conditioning_info,
|
||||
all_timesteps_count=steps
|
||||
)
|
||||
|
||||
if self.free_gpu_mem:
|
||||
@ -100,8 +105,49 @@ class Txt2Img2Img(Generator):
|
||||
|
||||
return self.sample_to_image(samples)
|
||||
|
||||
return make_image
|
||||
|
||||
# in the case of the inpainting model being loaded, the trick of
|
||||
# providing an interpolated latent doesn't work, so we transiently
|
||||
# create a 512x512 PIL image, upscale it, and run the inpainting
|
||||
# over it in img2img mode. Because the inpaing model is so conservative
|
||||
# it doesn't change the image (much)
|
||||
def inpaint_make_image(x_T):
|
||||
omnibus = Omnibus(self.model,self.precision)
|
||||
result = omnibus.generate(
|
||||
prompt,
|
||||
sampler=sampler,
|
||||
width=init_width,
|
||||
height=init_height,
|
||||
step_callback=step_callback,
|
||||
steps = steps,
|
||||
cfg_scale = cfg_scale,
|
||||
ddim_eta = ddim_eta,
|
||||
conditioning = conditioning,
|
||||
**kwargs
|
||||
)
|
||||
assert result is not None and len(result)>0,'** txt2img failed **'
|
||||
image = result[0][0]
|
||||
interpolated_image = image.resize((width,height),resample=Image.Resampling.LANCZOS)
|
||||
print(kwargs.pop('init_image',None))
|
||||
result = omnibus.generate(
|
||||
prompt,
|
||||
sampler=sampler,
|
||||
init_image=interpolated_image,
|
||||
width=width,
|
||||
height=height,
|
||||
seed=result[0][1],
|
||||
step_callback=step_callback,
|
||||
steps = steps,
|
||||
cfg_scale = cfg_scale,
|
||||
ddim_eta = ddim_eta,
|
||||
conditioning = conditioning,
|
||||
**kwargs
|
||||
)
|
||||
return result[0][0]
|
||||
|
||||
if sampler.uses_inpainting_model():
|
||||
return inpaint_make_image
|
||||
else:
|
||||
return make_image
|
||||
|
||||
# returns a tensor filled with random numbers from a normal distribution
|
||||
def get_noise(self,width,height,scale = True):
|
||||
@ -129,3 +175,4 @@ class Txt2Img2Img(Generator):
|
||||
scaled_height // self.downsampling_factor,
|
||||
scaled_width // self.downsampling_factor],
|
||||
device=device)
|
||||
|
||||
|
Reference in New Issue
Block a user