''' Base class for ldm.invoke.generator.* including img2img, txt2img, and inpaint ''' import torch import numpy as np import random import os from tqdm import tqdm, trange from PIL import Image from einops import rearrange, repeat from pytorch_lightning import seed_everything from ldm.invoke.devices import choose_autocast from ldm.util import rand_perlin_2d downsampling = 8 class Generator(): def __init__(self, model, precision): self.model = model self.precision = precision self.seed = None self.latent_channels = model.channels self.downsampling_factor = downsampling # BUG: should come from model or config self.perlin = 0.0 self.threshold = 0 self.variation_amount = 0 self.with_variations = [] # this is going to be overridden in img2img.py, txt2img.py and inpaint.py def get_make_image(self,prompt,**kwargs): """ Returns a function returning an image derived from the prompt and the initial image Return value depends on the seed at the time you call it """ raise NotImplementedError("image_iterator() must be implemented in a descendent class") def set_variation(self, seed, variation_amount, with_variations): self.seed = seed self.variation_amount = variation_amount self.with_variations = with_variations def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None, image_callback=None, step_callback=None, threshold=0.0, perlin=0.0, **kwargs): scope = choose_autocast(self.precision) make_image = self.get_make_image( prompt, sampler = sampler, init_image = init_image, width = width, height = height, step_callback = step_callback, threshold = threshold, perlin = perlin, **kwargs ) results = [] seed = seed if seed is not None else self.new_seed() first_seed = seed seed, initial_noise = self.generate_initial_noise(seed, width, height) scope = (scope(self.model.device.type), self.model.ema_scope()) if sampler.conditioning_key() not in ('hybrid','concat') else scope(self.model.device.type) with scope: for n in trange(iterations, desc='Generating'): print('DEBUG: in iterations loop() called') x_T = None if self.variation_amount > 0: seed_everything(seed) target_noise = self.get_noise(width,height) x_T = self.slerp(self.variation_amount, initial_noise, target_noise) elif initial_noise is not None: # i.e. we specified particular variations x_T = initial_noise else: seed_everything(seed) try: x_T = self.get_noise(width,height) except: pass image = make_image(x_T) results.append([image, seed]) if image_callback is not None: image_callback(image, seed, first_seed=first_seed) seed = self.new_seed() return results def sample_to_image(self,samples)->Image.Image: """ Given samples returned from a sampler, converts it into a PIL Image """ x_samples = self.model.decode_first_stage(samples) x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) if len(x_samples) != 1: raise Exception( f'>> expected to get a single image, but got {len(x_samples)}') x_sample = 255.0 * rearrange( x_samples[0].cpu().numpy(), 'c h w -> h w c' ) return Image.fromarray(x_sample.astype(np.uint8)) def generate_initial_noise(self, seed, width, height): initial_noise = None if self.variation_amount > 0 or len(self.with_variations) > 0: # use fixed initial noise plus random noise per iteration seed_everything(seed) initial_noise = self.get_noise(width,height) for v_seed, v_weight in self.with_variations: seed = v_seed seed_everything(seed) next_noise = self.get_noise(width,height) initial_noise = self.slerp(v_weight, initial_noise, next_noise) if self.variation_amount > 0: random.seed() # reset RNG to an actually random state, so we can get a random seed for variations seed = random.randrange(0,np.iinfo(np.uint32).max) return (seed, initial_noise) else: return (seed, None) # returns a tensor filled with random numbers from a normal distribution def get_noise(self,width,height): """ Returns a tensor filled with random numbers, either form a normal distribution (txt2img) or from the latent image (img2img, inpaint) """ raise NotImplementedError("get_noise() must be implemented in a descendent class") def get_perlin_noise(self,width,height): fixdevice = 'cpu' if (self.model.device.type == 'mps') else self.model.device return torch.stack([rand_perlin_2d((height, width), (8, 8), device = self.model.device).to(fixdevice) for _ in range(self.latent_channels)], dim=0).to(self.model.device) def new_seed(self): self.seed = random.randrange(0, np.iinfo(np.uint32).max) return self.seed def slerp(self, t, v0, v1, DOT_THRESHOLD=0.9995): ''' Spherical linear interpolation Args: t (float/np.ndarray): Float value between 0.0 and 1.0 v0 (np.ndarray): Starting vector v1 (np.ndarray): Final vector DOT_THRESHOLD (float): Threshold for considering the two vectors as colineal. Not recommended to alter this. Returns: v2 (np.ndarray): Interpolation vector between v0 and v1 ''' inputs_are_torch = False if not isinstance(v0, np.ndarray): inputs_are_torch = True v0 = v0.detach().cpu().numpy() if not isinstance(v1, np.ndarray): inputs_are_torch = True v1 = v1.detach().cpu().numpy() dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1))) if np.abs(dot) > DOT_THRESHOLD: v2 = (1 - t) * v0 + t * v1 else: theta_0 = np.arccos(dot) sin_theta_0 = np.sin(theta_0) theta_t = theta_0 * t sin_theta_t = np.sin(theta_t) s0 = np.sin(theta_0 - theta_t) / sin_theta_0 s1 = sin_theta_t / sin_theta_0 v2 = s0 * v0 + s1 * v1 if inputs_are_torch: v2 = torch.from_numpy(v2).to(self.model.device) return v2 # this is a handy routine for debugging use. Given a generated sample, # convert it into a PNG image and store it at the indicated path def save_sample(self, sample, filepath): image = self.sample_to_image(sample) dirname = os.path.dirname(filepath) or '.' if not os.path.exists(dirname): print(f'** creating directory {dirname}') os.makedirs(dirname, exist_ok=True) image.save(filepath,'PNG')