""" Base class for invokeai.backend.generator.* including img2img, txt2img, and inpaint """ from __future__ import annotations import importlib import itertools import dataclasses import diffusers import os import random import traceback from abc import ABCMeta, abstractmethod from contextlib import nullcontext import cv2 import numpy as np import torch from PIL import Image, ImageChops, ImageFilter from accelerate.utils import set_seed from diffusers import DiffusionPipeline from tqdm import trange from typing import List, Type, Iterator from dataclasses import dataclass, field from diffusers.schedulers import SchedulerMixin as Scheduler from ..util.util import rand_perlin_2d from ..safety_checker import SafetyChecker from ..prompting.conditioning import get_uc_and_c_and_ec from ..model_management.model_manager import ModelManager from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline downsampling = 8 @dataclass class InvokeAIGeneratorBasicParams: seed: int=None width: int=512 height: int=512 cfg_scale: int=7.5 steps: int=20 ddim_eta: float=0.0 model_name: str='stable-diffusion-1.5' scheduler: int='ddim' precision: str='float16' perlin: float=0.0 threshold: int=0.0 h_symmetry_time_pct: float=None v_symmetry_time_pct: float=None variation_amount: float = 0.0 with_variations: list=field(default_factory=list) safety_checker: SafetyChecker=None @dataclass class InvokeAIGeneratorOutput: ''' InvokeAIGeneratorOutput is a dataclass that contains the outputs of a generation operation, including the image, its seed, the model name used to generate the image and the model hash, as well as all the generate() parameters that went into generating the image (in .params, also available as attributes) ''' image: Image seed: int model_name: str model_hash: str params: dict def __getattribute__(self,name): try: return object.__getattribute__(self, name) except AttributeError: params = object.__getattribute__(self, 'params') if name in params: return params[name] raise AttributeError(f"'{self.__class__.__name__}' has no attribute '{name}'") class InvokeAIGeneratorFactory(object): def __init__(self, model_manager: ModelManager, params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(), ): self.model_manager = model_manager self.params = params def make_generator(self, generatorclass: Type[InvokeAIGenerator], **keyword_args)->InvokeAIGenerator: return generatorclass(self.model_manager, self.params, **keyword_args ) # getter and setter shortcuts for commonly used parameters @property def model_name(self)->str: return self.params.model_name @model_name.setter def model_name(self, model_name: str): self.params.model_name=model_name # we are interposing a wrapper around the original Generator classes so that # old code that calls Generate will continue to work. class InvokeAIGenerator(metaclass=ABCMeta): scheduler_map = dict( ddim=diffusers.DDIMScheduler, dpmpp_2=diffusers.DPMSolverMultistepScheduler, k_dpm_2=diffusers.KDPM2DiscreteScheduler, k_dpm_2_a=diffusers.KDPM2AncestralDiscreteScheduler, k_dpmpp_2=diffusers.DPMSolverMultistepScheduler, k_euler=diffusers.EulerDiscreteScheduler, k_euler_a=diffusers.EulerAncestralDiscreteScheduler, k_heun=diffusers.HeunDiscreteScheduler, k_lms=diffusers.LMSDiscreteScheduler, plms=diffusers.PNDMScheduler, ) def __init__(self, model_manager: ModelManager, params: InvokeAIGeneratorBasicParams, ): self.model_manager=model_manager self.params=params def generate(self, prompt: str='', callback: callable=None, step_callback: callable=None, iterations: int=1, **keyword_args, )->Iterator[InvokeAIGeneratorOutput]: ''' Return an iterator across the indicated number of generations. Each time the iterator is called it will return an InvokeAIGeneratorOutput object. Use like this: outputs = txt2img.generate(prompt='banana sushi', iterations=5) for result in outputs: print(result.image, result.seed) In the typical case of wanting to get just a single image, iterations defaults to 1 and do: output = next(txt2img.generate(prompt='banana sushi') Pass None to get an infinite iterator. outputs = txt2img.generate(prompt='banana sushi', iterations=None) for o in outputs: print(o.image, o.seed) ''' model_name = self.params.model_name or self.model_manager.current_model model_info: dict = self.model_manager.get_model(model_name) model:StableDiffusionGeneratorPipeline = model_info['model'] model_hash = model_info['hash'] scheduler: Scheduler = self.get_scheduler( model=model, scheduler_name=self.params.scheduler ) uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt,model=model) def _wrap_results(image: Image, seed: int, **kwargs): nonlocal results results.append(output) generator = self.load_generator(model, self._generator_name()) if self.params.variation_amount > 0: generator.set_variation(self.params.seed, self.params.variation_amount, self.params.with_variations) generator_args = dataclasses.asdict(self.params) generator_args.update(keyword_args) iteration_count = range(iterations) if iterations else itertools.count(start=0, step=1) for i in iteration_count: results = generator.generate(prompt, conditioning=(uc, c, extra_conditioning_info), sampler=scheduler, **generator_args, ) output = InvokeAIGeneratorOutput( image=results[0][0], seed=results[0][1], model_name = model_name, model_hash = model_hash, params=generator_args, ) if callback: callback(output) yield output @classmethod def schedulers(self)->List[str]: ''' Return list of all the schedulers that we currently handle. ''' return list(self.scheduler_map.keys()) def load_generator(self, model: StableDiffusionGeneratorPipeline, class_name: str): module_name = f'invokeai.backend.generator.{class_name.lower()}' module = importlib.import_module(module_name) constructor = getattr(module, class_name) return constructor(model, self.params.precision) def get_scheduler(self, scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler: scheduler_class = self.scheduler_map.get(scheduler_name,'ddim') scheduler = scheduler_class.from_config(model.scheduler.config) # hack copied over from generate.py if not hasattr(scheduler, 'uses_inpainting_model'): scheduler.uses_inpainting_model = lambda: False return scheduler @abstractmethod def _generator_name(self)->str: ''' In derived classes will return the name of the generator to use. ''' pass # ------------------------------------ class Txt2Img(InvokeAIGenerator): def _generator_name(self)->str: return 'Txt2Img' # ------------------------------------ class Img2Img(InvokeAIGenerator): def generate(self, init_image: Image | torch.FloatTensor, strength: float=0.75, **keyword_args )->List[InvokeAIGeneratorOutput]: return super().generate(init_image=init_image, strength=strength, **keyword_args ) def _generator_name(self)->str: return 'Img2Img' # ------------------------------------ # Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff class Inpaint(Img2Img): def generate(self, mask_image: Image | torch.FloatTensor, # Seam settings - when 0, doesn't fill seam seam_size: int = 0, seam_blur: int = 0, seam_strength: float = 0.7, seam_steps: int = 10, tile_size: int = 32, inpaint_replace=False, infill_method=None, inpaint_width=None, inpaint_height=None, inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF), **keyword_args )->List[InvokeAIGeneratorOutput]: return super().generate( mask_image=mask_image, seam_size=seam_size, seam_blur=seam_blur, seam_strength=seam_strength, seam_steps=seam_steps, tile_size=tile_size, inpaint_replace=inpaint_replace, infill_method=infill_method, inpaint_width=inpaint_width, inpaint_height=inpaint_height, inpaint_fill=inpaint_fill, **keyword_args ) def _generator_name(self)->str: return 'Inpaint' class Generator: downsampling_factor: int latent_channels: int precision: str model: DiffusionPipeline def __init__(self, model: DiffusionPipeline, precision: str): self.model = model self.precision = precision self.seed = None self.latent_channels = model.channels self.downsampling_factor = downsampling # BUG: should come from model or config self.safety_checker = None self.perlin = 0.0 self.threshold = 0 self.variation_amount = 0 self.with_variations = [] self.use_mps_noise = False self.free_gpu_mem = None # this is going to be overridden in img2img.py, txt2img.py and inpaint.py def get_make_image(self, prompt, **kwargs): """ Returns a function returning an image derived from the prompt and the initial image Return value depends on the seed at the time you call it """ raise NotImplementedError( "image_iterator() must be implemented in a descendent class" ) def set_variation(self, seed, variation_amount, with_variations): self.seed = seed self.variation_amount = variation_amount self.with_variations = with_variations def generate( self, prompt, width, height, sampler, init_image=None, iterations=1, seed=None, image_callback=None, step_callback=None, threshold=0.0, perlin=0.0, h_symmetry_time_pct=None, v_symmetry_time_pct=None, safety_checker: SafetyChecker=None, free_gpu_mem: bool = False, **kwargs, ): scope = nullcontext self.safety_checker = safety_checker self.free_gpu_mem = free_gpu_mem attention_maps_images = [] attention_maps_callback = lambda saver: attention_maps_images.append( saver.get_stacked_maps_image() ) make_image = self.get_make_image( prompt, sampler=sampler, init_image=init_image, width=width, height=height, step_callback=step_callback, threshold=threshold, perlin=perlin, h_symmetry_time_pct=h_symmetry_time_pct, v_symmetry_time_pct=v_symmetry_time_pct, attention_maps_callback=attention_maps_callback, seed=seed, **kwargs, ) results = [] seed = seed if seed is not None and seed >= 0 else self.new_seed() first_seed = seed seed, initial_noise = self.generate_initial_noise(seed, width, height) # There used to be an additional self.model.ema_scope() here, but it breaks # the inpaint-1.5 model. Not sure what it did.... ? with scope(self.model.device.type): for n in trange(iterations, desc="Generating"): x_T = None if self.variation_amount > 0: set_seed(seed) target_noise = self.get_noise(width, height) x_T = self.slerp(self.variation_amount, initial_noise, target_noise) elif initial_noise is not None: # i.e. we specified particular variations x_T = initial_noise else: set_seed(seed) try: x_T = self.get_noise(width, height) except: print("** An error occurred while getting initial noise **") print(traceback.format_exc()) image = make_image(x_T) if self.safety_checker is not None: image = self.safety_checker.check(image) results.append([image, seed]) if image_callback is not None: attention_maps_image = ( None if len(attention_maps_images) == 0 else attention_maps_images[-1] ) image_callback( image, seed, first_seed=first_seed, attention_maps_image=attention_maps_image, ) seed = self.new_seed() # Free up memory from the last generation. clear_cuda_cache = ( kwargs["clear_cuda_cache"] if "clear_cuda_cache" in kwargs else None ) if clear_cuda_cache is not None: clear_cuda_cache() return results def sample_to_image(self, samples) -> Image.Image: """ Given samples returned from a sampler, converts it into a PIL Image """ with torch.inference_mode(): image = self.model.decode_latents(samples) return self.model.numpy_to_pil(image)[0] def repaste_and_color_correct( self, result: Image.Image, init_image: Image.Image, init_mask: Image.Image, mask_blur_radius: int = 8, ) -> Image.Image: if init_image is None or init_mask is None: return result # Get the original alpha channel of the mask if there is one. # Otherwise it is some other black/white image format ('1', 'L' or 'RGB') pil_init_mask = ( init_mask.getchannel("A") if init_mask.mode == "RGBA" else init_mask.convert("L") ) pil_init_image = init_image.convert( "RGBA" ) # Add an alpha channel if one doesn't exist # Build an image with only visible pixels from source to use as reference for color-matching. init_rgb_pixels = np.asarray(init_image.convert("RGB"), dtype=np.uint8) init_a_pixels = np.asarray(pil_init_image.getchannel("A"), dtype=np.uint8) init_mask_pixels = np.asarray(pil_init_mask, dtype=np.uint8) # Get numpy version of result np_image = np.asarray(result, dtype=np.uint8) # Mask and calculate mean and standard deviation mask_pixels = init_a_pixels * init_mask_pixels > 0 np_init_rgb_pixels_masked = init_rgb_pixels[mask_pixels, :] np_image_masked = np_image[mask_pixels, :] if np_init_rgb_pixels_masked.size > 0: init_means = np_init_rgb_pixels_masked.mean(axis=0) init_std = np_init_rgb_pixels_masked.std(axis=0) gen_means = np_image_masked.mean(axis=0) gen_std = np_image_masked.std(axis=0) # Color correct np_matched_result = np_image.copy() np_matched_result[:, :, :] = ( ( ( ( np_matched_result[:, :, :].astype(np.float32) - gen_means[None, None, :] ) / gen_std[None, None, :] ) * init_std[None, None, :] + init_means[None, None, :] ) .clip(0, 255) .astype(np.uint8) ) matched_result = Image.fromarray(np_matched_result, mode="RGB") else: matched_result = Image.fromarray(np_image, mode="RGB") # Blur the mask out (into init image) by specified amount if mask_blur_radius > 0: nm = np.asarray(pil_init_mask, dtype=np.uint8) nmd = cv2.erode( nm, kernel=np.ones((3, 3), dtype=np.uint8), iterations=int(mask_blur_radius / 2), ) pmd = Image.fromarray(nmd, mode="L") blurred_init_mask = pmd.filter(ImageFilter.BoxBlur(mask_blur_radius)) else: blurred_init_mask = pil_init_mask multiplied_blurred_init_mask = ImageChops.multiply( blurred_init_mask, self.pil_image.split()[-1] ) # Paste original on color-corrected generation (using blurred mask) matched_result.paste(init_image, (0, 0), mask=multiplied_blurred_init_mask) return matched_result def sample_to_lowres_estimated_image(self, samples): # origingally adapted from code by @erucipe and @keturn here: # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7 # these updated numbers for v1.5 are from @torridgristle v1_5_latent_rgb_factors = torch.tensor( [ # R G B [0.3444, 0.1385, 0.0670], # L1 [0.1247, 0.4027, 0.1494], # L2 [-0.3192, 0.2513, 0.2103], # L3 [-0.1307, -0.1874, -0.7445], # L4 ], dtype=samples.dtype, device=samples.device, ) latent_image = samples[0].permute(1, 2, 0) @ v1_5_latent_rgb_factors latents_ubyte = ( ((latent_image + 1) / 2) .clamp(0, 1) # change scale from -1..1 to 0..1 .mul(0xFF) # to 0..255 .byte() ).cpu() return Image.fromarray(latents_ubyte.numpy()) def generate_initial_noise(self, seed, width, height): initial_noise = None if self.variation_amount > 0 or len(self.with_variations) > 0: # use fixed initial noise plus random noise per iteration set_seed(seed) initial_noise = self.get_noise(width, height) for v_seed, v_weight in self.with_variations: seed = v_seed set_seed(seed) next_noise = self.get_noise(width, height) initial_noise = self.slerp(v_weight, initial_noise, next_noise) if self.variation_amount > 0: random.seed() # reset RNG to an actually random state, so we can get a random seed for variations seed = random.randrange(0, np.iinfo(np.uint32).max) return (seed, initial_noise) def get_perlin_noise(self, width, height): fixdevice = "cpu" if (self.model.device.type == "mps") else self.model.device # limit noise to only the diffusion image channels, not the mask channels input_channels = min(self.latent_channels, 4) # round up to the nearest block of 8 temp_width = int((width + 7) / 8) * 8 temp_height = int((height + 7) / 8) * 8 noise = torch.stack( [ rand_perlin_2d( (temp_height, temp_width), (8, 8), device=self.model.device ).to(fixdevice) for _ in range(input_channels) ], dim=0, ).to(self.model.device) return noise[0:4, 0:height, 0:width] def new_seed(self): self.seed = random.randrange(0, np.iinfo(np.uint32).max) return self.seed def slerp(self, t, v0, v1, DOT_THRESHOLD=0.9995): """ Spherical linear interpolation Args: t (float/np.ndarray): Float value between 0.0 and 1.0 v0 (np.ndarray): Starting vector v1 (np.ndarray): Final vector DOT_THRESHOLD (float): Threshold for considering the two vectors as colineal. Not recommended to alter this. Returns: v2 (np.ndarray): Interpolation vector between v0 and v1 """ inputs_are_torch = False if not isinstance(v0, np.ndarray): inputs_are_torch = True v0 = v0.detach().cpu().numpy() if not isinstance(v1, np.ndarray): inputs_are_torch = True v1 = v1.detach().cpu().numpy() dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1))) if np.abs(dot) > DOT_THRESHOLD: v2 = (1 - t) * v0 + t * v1 else: theta_0 = np.arccos(dot) sin_theta_0 = np.sin(theta_0) theta_t = theta_0 * t sin_theta_t = np.sin(theta_t) s0 = np.sin(theta_0 - theta_t) / sin_theta_0 s1 = sin_theta_t / sin_theta_0 v2 = s0 * v0 + s1 * v1 if inputs_are_torch: v2 = torch.from_numpy(v2).to(self.model.device) return v2 # this is a handy routine for debugging use. Given a generated sample, # convert it into a PNG image and store it at the indicated path def save_sample(self, sample, filepath): image = self.sample_to_image(sample) dirname = os.path.dirname(filepath) or "." if not os.path.exists(dirname): print(f"** creating directory {dirname}") os.makedirs(dirname, exist_ok=True) image.save(filepath, "PNG") def torch_dtype(self) -> torch.dtype: return torch.float16 if self.precision == "float16" else torch.float32 # returns a tensor filled with random numbers from a normal distribution def get_noise(self, width, height): device = self.model.device # limit noise to only the diffusion image channels, not the mask channels input_channels = min(self.latent_channels, 4) if self.use_mps_noise or device.type == "mps": x = torch.randn( [ 1, input_channels, height // self.downsampling_factor, width // self.downsampling_factor, ], dtype=self.torch_dtype(), device="cpu", ).to(device) else: x = torch.randn( [ 1, input_channels, height // self.downsampling_factor, width // self.downsampling_factor, ], dtype=self.torch_dtype(), device=device, ) if self.perlin > 0.0: perlin_noise = self.get_perlin_noise( width // self.downsampling_factor, height // self.downsampling_factor ) x = (1 - self.perlin) * x + self.perlin * perlin_noise return x