mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
308 lines
13 KiB
Python
308 lines
13 KiB
Python
'''
|
|
Base class for ldm.invoke.generator.*
|
|
including img2img, txt2img, and inpaint
|
|
'''
|
|
import torch
|
|
import numpy as np
|
|
import random
|
|
import os
|
|
import traceback
|
|
from tqdm import tqdm, trange
|
|
from PIL import Image, ImageFilter, ImageChops
|
|
import cv2 as cv
|
|
from einops import rearrange, repeat
|
|
from pytorch_lightning import seed_everything
|
|
from ldm.invoke.devices import choose_autocast
|
|
from ldm.util import rand_perlin_2d
|
|
|
|
downsampling = 8
|
|
CAUTION_IMG = 'assets/caution.png'
|
|
|
|
class Generator():
|
|
def __init__(self, model, precision):
|
|
self.model = model
|
|
self.precision = precision
|
|
self.seed = None
|
|
self.latent_channels = model.channels
|
|
self.downsampling_factor = downsampling # BUG: should come from model or config
|
|
self.safety_checker = None
|
|
self.perlin = 0.0
|
|
self.threshold = 0
|
|
self.variation_amount = 0
|
|
self.with_variations = []
|
|
self.use_mps_noise = False
|
|
self.free_gpu_mem = None
|
|
|
|
# this is going to be overridden in img2img.py, txt2img.py and inpaint.py
|
|
def get_make_image(self,prompt,**kwargs):
|
|
"""
|
|
Returns a function returning an image derived from the prompt and the initial image
|
|
Return value depends on the seed at the time you call it
|
|
"""
|
|
raise NotImplementedError("image_iterator() must be implemented in a descendent class")
|
|
|
|
def set_variation(self, seed, variation_amount, with_variations):
|
|
self.seed = seed
|
|
self.variation_amount = variation_amount
|
|
self.with_variations = with_variations
|
|
|
|
def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None,
|
|
image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
|
|
safety_checker:dict=None,
|
|
**kwargs):
|
|
scope = choose_autocast(self.precision)
|
|
self.safety_checker = safety_checker
|
|
make_image = self.get_make_image(
|
|
prompt,
|
|
sampler = sampler,
|
|
init_image = init_image,
|
|
width = width,
|
|
height = height,
|
|
step_callback = step_callback,
|
|
threshold = threshold,
|
|
perlin = perlin,
|
|
**kwargs
|
|
)
|
|
results = []
|
|
seed = seed if seed is not None and seed >= 0 else self.new_seed()
|
|
first_seed = seed
|
|
seed, initial_noise = self.generate_initial_noise(seed, width, height)
|
|
|
|
# There used to be an additional self.model.ema_scope() here, but it breaks
|
|
# the inpaint-1.5 model. Not sure what it did.... ?
|
|
with scope(self.model.device.type):
|
|
for n in trange(iterations, desc='Generating'):
|
|
x_T = None
|
|
if self.variation_amount > 0:
|
|
seed_everything(seed)
|
|
target_noise = self.get_noise(width,height)
|
|
x_T = self.slerp(self.variation_amount, initial_noise, target_noise)
|
|
elif initial_noise is not None:
|
|
# i.e. we specified particular variations
|
|
x_T = initial_noise
|
|
else:
|
|
seed_everything(seed)
|
|
try:
|
|
x_T = self.get_noise(width,height)
|
|
except:
|
|
print('** An error occurred while getting initial noise **')
|
|
print(traceback.format_exc())
|
|
|
|
image = make_image(x_T)
|
|
|
|
if self.safety_checker is not None:
|
|
image = self.safety_check(image)
|
|
|
|
results.append([image, seed])
|
|
|
|
if image_callback is not None:
|
|
image_callback(image, seed, first_seed=first_seed)
|
|
|
|
seed = self.new_seed()
|
|
|
|
return results
|
|
|
|
def sample_to_image(self,samples)->Image.Image:
|
|
"""
|
|
Given samples returned from a sampler, converts
|
|
it into a PIL Image
|
|
"""
|
|
x_samples = self.model.decode_first_stage(samples)
|
|
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
|
if len(x_samples) != 1:
|
|
raise Exception(
|
|
f'>> expected to get a single image, but got {len(x_samples)}')
|
|
x_sample = 255.0 * rearrange(
|
|
x_samples[0].cpu().numpy(), 'c h w -> h w c'
|
|
)
|
|
return Image.fromarray(x_sample.astype(np.uint8))
|
|
|
|
# write an approximate RGB image from latent samples for a single step to PNG
|
|
|
|
def repaste_and_color_correct(self, result: Image.Image, init_image: Image.Image, init_mask: Image.Image, mask_blur_radius: int = 8) -> Image.Image:
|
|
if init_image is None or init_mask is None:
|
|
return result
|
|
|
|
# Get the original alpha channel of the mask if there is one.
|
|
# Otherwise it is some other black/white image format ('1', 'L' or 'RGB')
|
|
pil_init_mask = init_mask.getchannel('A') if init_mask.mode == 'RGBA' else init_mask.convert('L')
|
|
pil_init_image = init_image.convert('RGBA') # Add an alpha channel if one doesn't exist
|
|
|
|
# Build an image with only visible pixels from source to use as reference for color-matching.
|
|
init_rgb_pixels = np.asarray(init_image.convert('RGB'), dtype=np.uint8)
|
|
init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8)
|
|
init_mask_pixels = np.asarray(pil_init_mask, dtype=np.uint8)
|
|
|
|
# Get numpy version of result
|
|
np_image = np.asarray(result, dtype=np.uint8)
|
|
|
|
# Mask and calculate mean and standard deviation
|
|
mask_pixels = init_a_pixels * init_mask_pixels > 0
|
|
np_init_rgb_pixels_masked = init_rgb_pixels[mask_pixels, :]
|
|
np_image_masked = np_image[mask_pixels, :]
|
|
|
|
init_means = np_init_rgb_pixels_masked.mean(axis=0)
|
|
init_std = np_init_rgb_pixels_masked.std(axis=0)
|
|
gen_means = np_image_masked.mean(axis=0)
|
|
gen_std = np_image_masked.std(axis=0)
|
|
|
|
# Color correct
|
|
np_matched_result = np_image.copy()
|
|
np_matched_result[:,:,:] = (((np_matched_result[:,:,:].astype(np.float32) - gen_means[None,None,:]) / gen_std[None,None,:]) * init_std[None,None,:] + init_means[None,None,:]).clip(0, 255).astype(np.uint8)
|
|
matched_result = Image.fromarray(np_matched_result, mode='RGB')
|
|
|
|
# Blur the mask out (into init image) by specified amount
|
|
if mask_blur_radius > 0:
|
|
nm = np.asarray(pil_init_mask, dtype=np.uint8)
|
|
nmd = cv.erode(nm, kernel=np.ones((3,3), dtype=np.uint8), iterations=int(mask_blur_radius / 2))
|
|
pmd = Image.fromarray(nmd, mode='L')
|
|
blurred_init_mask = pmd.filter(ImageFilter.BoxBlur(mask_blur_radius))
|
|
else:
|
|
blurred_init_mask = pil_init_mask
|
|
|
|
multiplied_blurred_init_mask = ImageChops.multiply(blurred_init_mask, self.pil_image.split()[-1])
|
|
|
|
# Paste original on color-corrected generation (using blurred mask)
|
|
matched_result.paste(init_image, (0,0), mask = multiplied_blurred_init_mask)
|
|
return matched_result
|
|
|
|
|
|
|
|
def sample_to_lowres_estimated_image(self,samples):
|
|
# origingally adapted from code by @erucipe and @keturn here:
|
|
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
|
|
|
|
# these updated numbers for v1.5 are from @torridgristle
|
|
v1_5_latent_rgb_factors = torch.tensor([
|
|
# R G B
|
|
[ 0.3444, 0.1385, 0.0670], # L1
|
|
[ 0.1247, 0.4027, 0.1494], # L2
|
|
[-0.3192, 0.2513, 0.2103], # L3
|
|
[-0.1307, -0.1874, -0.7445] # L4
|
|
], dtype=samples.dtype, device=samples.device)
|
|
|
|
latent_image = samples[0].permute(1, 2, 0) @ v1_5_latent_rgb_factors
|
|
latents_ubyte = (((latent_image + 1) / 2)
|
|
.clamp(0, 1) # change scale from -1..1 to 0..1
|
|
.mul(0xFF) # to 0..255
|
|
.byte()).cpu()
|
|
|
|
return Image.fromarray(latents_ubyte.numpy())
|
|
|
|
def generate_initial_noise(self, seed, width, height):
|
|
initial_noise = None
|
|
if self.variation_amount > 0 or len(self.with_variations) > 0:
|
|
# use fixed initial noise plus random noise per iteration
|
|
seed_everything(seed)
|
|
initial_noise = self.get_noise(width,height)
|
|
for v_seed, v_weight in self.with_variations:
|
|
seed = v_seed
|
|
seed_everything(seed)
|
|
next_noise = self.get_noise(width,height)
|
|
initial_noise = self.slerp(v_weight, initial_noise, next_noise)
|
|
if self.variation_amount > 0:
|
|
random.seed() # reset RNG to an actually random state, so we can get a random seed for variations
|
|
seed = random.randrange(0,np.iinfo(np.uint32).max)
|
|
return (seed, initial_noise)
|
|
else:
|
|
return (seed, None)
|
|
|
|
# returns a tensor filled with random numbers from a normal distribution
|
|
def get_noise(self,width,height):
|
|
"""
|
|
Returns a tensor filled with random numbers, either form a normal distribution
|
|
(txt2img) or from the latent image (img2img, inpaint)
|
|
"""
|
|
raise NotImplementedError("get_noise() must be implemented in a descendent class")
|
|
|
|
def get_perlin_noise(self,width,height):
|
|
fixdevice = 'cpu' if (self.model.device.type == 'mps') else self.model.device
|
|
return torch.stack([rand_perlin_2d((height, width), (8, 8), device = self.model.device).to(fixdevice) for _ in range(self.latent_channels)], dim=0).to(self.model.device)
|
|
|
|
def new_seed(self):
|
|
self.seed = random.randrange(0, np.iinfo(np.uint32).max)
|
|
return self.seed
|
|
|
|
def slerp(self, t, v0, v1, DOT_THRESHOLD=0.9995):
|
|
'''
|
|
Spherical linear interpolation
|
|
Args:
|
|
t (float/np.ndarray): Float value between 0.0 and 1.0
|
|
v0 (np.ndarray): Starting vector
|
|
v1 (np.ndarray): Final vector
|
|
DOT_THRESHOLD (float): Threshold for considering the two vectors as
|
|
colineal. Not recommended to alter this.
|
|
Returns:
|
|
v2 (np.ndarray): Interpolation vector between v0 and v1
|
|
'''
|
|
inputs_are_torch = False
|
|
if not isinstance(v0, np.ndarray):
|
|
inputs_are_torch = True
|
|
v0 = v0.detach().cpu().numpy()
|
|
if not isinstance(v1, np.ndarray):
|
|
inputs_are_torch = True
|
|
v1 = v1.detach().cpu().numpy()
|
|
|
|
dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
|
|
if np.abs(dot) > DOT_THRESHOLD:
|
|
v2 = (1 - t) * v0 + t * v1
|
|
else:
|
|
theta_0 = np.arccos(dot)
|
|
sin_theta_0 = np.sin(theta_0)
|
|
theta_t = theta_0 * t
|
|
sin_theta_t = np.sin(theta_t)
|
|
s0 = np.sin(theta_0 - theta_t) / sin_theta_0
|
|
s1 = sin_theta_t / sin_theta_0
|
|
v2 = s0 * v0 + s1 * v1
|
|
|
|
if inputs_are_torch:
|
|
v2 = torch.from_numpy(v2).to(self.model.device)
|
|
|
|
return v2
|
|
|
|
def safety_check(self,image:Image.Image):
|
|
'''
|
|
If the CompViz safety checker flags an NSFW image, we
|
|
blur it out.
|
|
'''
|
|
import diffusers
|
|
|
|
checker = self.safety_checker['checker']
|
|
extractor = self.safety_checker['extractor']
|
|
features = extractor([image], return_tensors="pt")
|
|
features.to(self.model.device)
|
|
|
|
# unfortunately checker requires the numpy version, so we have to convert back
|
|
x_image = np.array(image).astype(np.float32) / 255.0
|
|
x_image = x_image[None].transpose(0, 3, 1, 2)
|
|
|
|
diffusers.logging.set_verbosity_error()
|
|
checked_image, has_nsfw_concept = checker(images=x_image, clip_input=features.pixel_values)
|
|
if has_nsfw_concept[0]:
|
|
print('** An image with potential non-safe content has been detected. A blurred image will be returned. **')
|
|
return self.blur(image)
|
|
else:
|
|
return image
|
|
|
|
def blur(self,input):
|
|
blurry = input.filter(filter=ImageFilter.GaussianBlur(radius=32))
|
|
try:
|
|
caution = Image.open(CAUTION_IMG)
|
|
caution = caution.resize((caution.width // 2, caution.height //2))
|
|
blurry.paste(caution,(0,0),caution)
|
|
except FileNotFoundError:
|
|
pass
|
|
return blurry
|
|
|
|
# this is a handy routine for debugging use. Given a generated sample,
|
|
# convert it into a PNG image and store it at the indicated path
|
|
def save_sample(self, sample, filepath):
|
|
image = self.sample_to_image(sample)
|
|
dirname = os.path.dirname(filepath) or '.'
|
|
if not os.path.exists(dirname):
|
|
print(f'** creating directory {dirname}')
|
|
os.makedirs(dirname, exist_ok=True)
|
|
image.save(filepath,'PNG')
|
|
|
|
|