InvokeAI/ldm/invoke/generator/base.py
Lincoln Stein 6f93dc7712 cleanup inpainting and img2img
- add a `--inpaint_replace` option that fills masked regions with
  latent noise. This allows radical changes to inpainted regions
  at the cost of losing context.
- fix up readline, arg processing and metadata writing to accommodate
  this change
- fixed bug in storage and retrieval of variations, discovered incidentally
  during testing
- update documentation
2022-10-16 08:50:55 -04:00

183 lines
7.3 KiB
Python

'''
Base class for ldm.invoke.generator.*
including img2img, txt2img, and inpaint
'''
import torch
import numpy as np
import random
import os
from tqdm import tqdm, trange
from PIL import Image
from einops import rearrange, repeat
from pytorch_lightning import seed_everything
from ldm.invoke.devices import choose_autocast
from ldm.util import rand_perlin_2d
downsampling = 8  # default value copied into Generator.downsampling_factor (presumably pixel->latent ratio)


class Generator:
    '''
    Base class for the generators in ldm.invoke.generator (txt2img, img2img, inpaint).

    Subclasses must implement get_make_image() and get_noise(). This class provides
    the shared seed / variation bookkeeping, the iteration loop in generate(),
    decoding of latent samples to PIL images, and spherical interpolation (slerp).
    '''

    def __init__(self, model, precision):
        self.model = model
        self.precision = precision          # passed to choose_autocast() in generate()
        self.seed = None                    # last seed produced by new_seed() / set_variation()
        self.latent_channels = model.channels
        self.downsampling_factor = downsampling  # BUG: should come from model or config
        self.perlin = 0.0
        self.threshold = 0
        self.variation_amount = 0           # > 0 enables per-iteration slerp toward fresh noise
        self.with_variations = []           # list of (variation_seed, weight) pairs

    # this is going to be overridden in img2img.py, txt2img.py and inpaint.py
    def get_make_image(self, prompt, **kwargs):
        """
        Return a callable that produces one image from the prompt (and initial image).

        generate() invokes the returned callable as make_image(x_T), where x_T is the
        starting noise tensor or None. The result depends on the RNG seed in effect
        at the time the callable is called.

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError("get_make_image() must be implemented in a descendent class")

    def set_variation(self, seed, variation_amount, with_variations):
        """Record the seed and variation settings used by generate_initial_noise()/generate()."""
        self.seed = seed
        self.variation_amount = variation_amount
        self.with_variations = with_variations

    def generate(self, prompt, init_image, width, height, iterations=1, seed=None,
                 image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
                 **kwargs):
        """
        Generate `iterations` images for `prompt`.

        Args:
            prompt: forwarded to get_make_image().
            init_image: forwarded to get_make_image().
            width, height: requested size; forwarded to get_make_image() and get_noise().
            iterations: number of images to produce.
            seed: seed for the first image; a random one is chosen when None.
            image_callback: optional callable, invoked after each image as
                image_callback(image, seed, first_seed=first_seed).
            step_callback, threshold, perlin, **kwargs: forwarded to get_make_image().

        Returns:
            A list of [image, seed] pairs, one per iteration.
        """
        scope = choose_autocast(self.precision)
        make_image = self.get_make_image(
            prompt,
            init_image=init_image,
            width=width,
            height=height,
            step_callback=step_callback,
            threshold=threshold,
            perlin=perlin,
            **kwargs
        )
        results = []
        seed = seed if seed is not None else self.new_seed()
        first_seed = seed
        seed, initial_noise = self.generate_initial_noise(seed, width, height)
        with scope(self.model.device.type), self.model.ema_scope():
            for _ in trange(iterations, desc='Generating'):
                x_T = None
                if self.variation_amount > 0:
                    # blend the fixed initial noise toward fresh noise drawn from this seed
                    seed_everything(seed)
                    target_noise = self.get_noise(width, height)
                    x_T = self.slerp(self.variation_amount, initial_noise, target_noise)
                elif initial_noise is not None:
                    # i.e. we specified particular variations
                    x_T = initial_noise
                else:
                    seed_everything(seed)
                    try:
                        x_T = self.get_noise(width, height)
                    except Exception:
                        # Best-effort: if the subclass cannot produce noise here, leave
                        # x_T as None and let make_image() decide. Deliberately not a bare
                        # except, so KeyboardInterrupt / SystemExit still propagate.
                        pass
                image = make_image(x_T)
                results.append([image, seed])
                if image_callback is not None:
                    image_callback(image, seed, first_seed=first_seed)
                seed = self.new_seed()
        return results

    def sample_to_image(self, samples):
        """
        Decode latent `samples` into a single PIL image.

        The samples are decoded with self.model.decode_first_stage(), transformed
        with (x + 1) / 2 and clamped to [0, 1], rearranged from 'c h w' to
        'h w c', scaled by 255 and converted to uint8.

        Raises:
            Exception: if the decoded batch does not contain exactly one image.
        """
        x_samples = self.model.decode_first_stage(samples)
        x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
        if len(x_samples) != 1:
            raise Exception(
                f'>> expected to get a single image, but got {len(x_samples)}')
        x_sample = 255.0 * rearrange(
            x_samples[0].cpu().numpy(), 'c h w -> h w c'
        )
        return Image.fromarray(x_sample.astype(np.uint8))

    def generate_initial_noise(self, seed, width, height):
        """
        Build the starting noise used when variations are requested.

        If variation_amount > 0 or with_variations is non-empty, the noise for `seed`
        is successively slerped toward the noise of each (v_seed, v_weight) pair in
        with_variations. After the loop the seed variable holds the last variation
        seed (if any). When variation_amount > 0 it is then replaced by a fresh
        random seed, which generate() uses to draw per-iteration target noise.

        Returns:
            (seed, initial_noise); initial_noise is None when no variations were requested.
        """
        initial_noise = None
        if self.variation_amount > 0 or len(self.with_variations) > 0:
            # use fixed initial noise plus random noise per iteration
            seed_everything(seed)
            initial_noise = self.get_noise(width, height)
            for v_seed, v_weight in self.with_variations:
                seed = v_seed
                seed_everything(seed)
                next_noise = self.get_noise(width, height)
                initial_noise = self.slerp(v_weight, initial_noise, next_noise)
            if self.variation_amount > 0:
                random.seed()  # reset RNG to an actually random state, so we can get a random seed for variations
                seed = random.randrange(0, np.iinfo(np.uint32).max)
            return (seed, initial_noise)
        else:
            return (seed, None)

    def get_noise(self, width, height):
        """
        Returns a tensor filled with random numbers, either from a normal distribution
        (txt2img) or from the latent image (img2img, inpaint).

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError("get_noise() must be implemented in a descendent class")

    def get_perlin_noise(self, width, height):
        """
        Return self.latent_channels Perlin-noise planes stacked along dim 0.

        Each plane comes from rand_perlin_2d((height, width), (8, 8)); the stacked
        tensor is moved to self.model.device.
        """
        # NOTE(review): on MPS each plane is moved to CPU before torch.stack and the
        # result is moved back — presumably a workaround for an MPS limitation; confirm.
        fixdevice = 'cpu' if (self.model.device.type == 'mps') else self.model.device
        return torch.stack(
            [rand_perlin_2d((height, width), (8, 8), device=self.model.device).to(fixdevice)
             for _ in range(self.latent_channels)],
            dim=0,
        ).to(self.model.device)

    def new_seed(self):
        """Pick a random seed in [0, uint32 max), store it in self.seed and return it."""
        self.seed = random.randrange(0, np.iinfo(np.uint32).max)
        return self.seed

    def slerp(self, t, v0, v1, DOT_THRESHOLD=0.9995):
        '''
        Spherical linear interpolation
        Args:
            t (float/np.ndarray): Float value between 0.0 and 1.0
            v0 (np.ndarray): Starting vector
            v1 (np.ndarray): Final vector
            DOT_THRESHOLD (float): Threshold for considering the two vectors as
                                   colineal. Not recommended to alter this.
        Returns:
            v2 (np.ndarray): Interpolation vector between v0 and v1.
                If either input was a torch tensor, the result is returned as a
                torch tensor on self.model.device instead.
        '''
        inputs_are_torch = False
        if not isinstance(v0, np.ndarray):
            inputs_are_torch = True
            v0 = v0.detach().cpu().numpy()
        if not isinstance(v1, np.ndarray):
            inputs_are_torch = True
            v1 = v1.detach().cpu().numpy()

        # cosine of the angle between the two vectors (treated as flat vectors)
        dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
        if np.abs(dot) > DOT_THRESHOLD:
            # nearly colinear: sin(theta_0) ~ 0, so fall back to plain lerp
            v2 = (1 - t) * v0 + t * v1
        else:
            theta_0 = np.arccos(dot)
            sin_theta_0 = np.sin(theta_0)
            theta_t = theta_0 * t
            sin_theta_t = np.sin(theta_t)
            s0 = np.sin(theta_0 - theta_t) / sin_theta_0
            s1 = sin_theta_t / sin_theta_0
            v2 = s0 * v0 + s1 * v1

        if inputs_are_torch:
            v2 = torch.from_numpy(v2).to(self.model.device)

        return v2

    # this is a handy routine for debugging use. Given a generated sample,
    # convert it into a PNG image and store it at the indicated path
    def save_sample(self, sample, filepath):
        """Decode `sample` with sample_to_image() and save it as a PNG at `filepath`."""
        image = self.sample_to_image(sample)
        dirname = os.path.dirname(filepath) or '.'
        if not os.path.exists(dirname):
            print(f'** creating directory {dirname}')
            os.makedirs(dirname, exist_ok=True)
        image.save(filepath, 'PNG')