Mirror of https://github.com/invoke-ai/InvokeAI
Synced 2024-08-30 20:32:17 +00:00

Commit 3b921cf393: add more missing files
Parent: d334f7f1f6
invokeai/generator/__init__.py (new file, 6 lines)
@@ -0,0 +1,6 @@
'''
Initialization file for the invokeai.generator package
'''
from .base import Generator
from .diffusers_pipeline import PipelineIntermediateState, StableDiffusionGeneratorPipeline
from .inpaint import infill_methods
invokeai/generator/base.py (new file, 374 lines)
@@ -0,0 +1,374 @@
'''
Base class for invokeai.backend.generator.*
including img2img, txt2img, and inpaint
'''
from __future__ import annotations

import os
import os.path as osp
import random
import traceback
from contextlib import nullcontext

import cv2
import numpy as np
import torch

from PIL import Image, ImageFilter, ImageChops
from diffusers import DiffusionPipeline
from einops import rearrange
from pathlib import Path
from pytorch_lightning import seed_everything
from tqdm import trange

import invokeai.assets.web as web_assets
from ..models.diffusion.ddpm import DiffusionWrapper
from ldm.util import rand_perlin_2d

downsampling = 8
CAUTION_IMG = 'caution.png'

class Generator:
|
||||
downsampling_factor: int
|
||||
latent_channels: int
|
||||
precision: str
|
||||
model: DiffusionWrapper | DiffusionPipeline
|
||||
|
||||
def __init__(self, model: DiffusionWrapper | DiffusionPipeline, precision: str):
|
||||
self.model = model
|
||||
self.precision = precision
|
||||
self.seed = None
|
||||
self.latent_channels = model.channels
|
||||
self.downsampling_factor = downsampling # BUG: should come from model or config
|
||||
self.safety_checker = None
|
||||
self.perlin = 0.0
|
||||
self.threshold = 0
|
||||
self.variation_amount = 0
|
||||
self.with_variations = []
|
||||
self.use_mps_noise = False
|
||||
self.free_gpu_mem = None
|
||||
self.caution_img = None
|
||||
|
||||
# this is going to be overridden in img2img.py, txt2img.py and inpaint.py
|
||||
def get_make_image(self,prompt,**kwargs):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and the initial image
|
||||
Return value depends on the seed at the time you call it
|
||||
"""
|
||||
raise NotImplementedError("image_iterator() must be implemented in a descendent class")
|
||||
|
||||
def set_variation(self, seed, variation_amount, with_variations):
|
||||
self.seed = seed
|
||||
self.variation_amount = variation_amount
|
||||
self.with_variations = with_variations
|
||||
|
||||
def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None,
|
||||
image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
|
||||
h_symmetry_time_pct=None, v_symmetry_time_pct=None,
|
||||
safety_checker:dict=None,
|
||||
free_gpu_mem: bool=False,
|
||||
**kwargs):
|
||||
scope = nullcontext
|
||||
self.safety_checker = safety_checker
|
||||
self.free_gpu_mem = free_gpu_mem
|
||||
attention_maps_images = []
|
||||
attention_maps_callback = lambda saver: attention_maps_images.append(saver.get_stacked_maps_image())
|
||||
make_image = self.get_make_image(
|
||||
prompt,
|
||||
sampler = sampler,
|
||||
init_image = init_image,
|
||||
width = width,
|
||||
height = height,
|
||||
step_callback = step_callback,
|
||||
threshold = threshold,
|
||||
perlin = perlin,
|
||||
h_symmetry_time_pct = h_symmetry_time_pct,
|
||||
v_symmetry_time_pct = v_symmetry_time_pct,
|
||||
attention_maps_callback = attention_maps_callback,
|
||||
**kwargs
|
||||
)
|
||||
results = []
|
||||
seed = seed if seed is not None and seed >= 0 else self.new_seed()
|
||||
first_seed = seed
|
||||
seed, initial_noise = self.generate_initial_noise(seed, width, height)
|
||||
|
||||
# There used to be an additional self.model.ema_scope() here, but it breaks
|
||||
# the inpaint-1.5 model. Not sure what it did.
|
||||
with scope(self.model.device.type):
|
||||
for n in trange(iterations, desc='Generating'):
|
||||
x_T = None
|
||||
if self.variation_amount > 0:
|
||||
seed_everything(seed)
|
||||
target_noise = self.get_noise(width,height)
|
||||
x_T = self.slerp(self.variation_amount, initial_noise, target_noise)
|
||||
elif initial_noise is not None:
|
||||
# i.e. we specified particular variations
|
||||
x_T = initial_noise
|
||||
else:
|
||||
seed_everything(seed)
|
||||
try:
|
||||
x_T = self.get_noise(width,height)
|
||||
except Exception:
|
||||
print('** An error occurred while getting initial noise **')
|
||||
print(traceback.format_exc())
|
||||
|
||||
image = make_image(x_T)
|
||||
|
||||
if self.safety_checker is not None:
|
||||
image = self.safety_check(image)
|
||||
|
||||
results.append([image, seed])
|
||||
|
||||
if image_callback is not None:
|
||||
attention_maps_image = None if len(attention_maps_images)==0 else attention_maps_images[-1]
|
||||
image_callback(image, seed, first_seed=first_seed, attention_maps_image=attention_maps_image)
|
||||
|
||||
seed = self.new_seed()
|
||||
|
||||
# Free up memory from the last generation.
|
||||
clear_cuda_cache = kwargs['clear_cuda_cache'] if 'clear_cuda_cache' in kwargs else None
|
||||
if clear_cuda_cache is not None:
|
||||
clear_cuda_cache()
|
||||
|
||||
return results
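# Illustrative usage sketch (hypothetical subclass and variable names): a
# concrete generator supplies get_make_image(), and callers drive generate()
# with an image_callback that receives each result and its seed, e.g.
#
#   class Txt2Img(Generator):
#       def get_make_image(self, prompt, **kwargs):
#           def make_image(x_T):
#               ...  # run the diffusion pipeline starting from noise x_T
#           return make_image
#
#   gen = Txt2Img(model, precision='float16')
#   results = gen.generate(prompt, None, 512, 512, sampler, iterations=2,
#                          image_callback=lambda img, seed, **kw: img.save(f'{seed}.png'))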
|
||||
|
||||
def sample_to_image(self,samples)->Image.Image:
|
||||
"""
|
||||
Given samples returned from a sampler, converts
|
||||
it into a PIL Image
|
||||
"""
|
||||
with torch.inference_mode():
|
||||
image = self.model.decode_latents(samples)
|
||||
return self.model.numpy_to_pil(image)[0]
|
||||
|
||||
def repaste_and_color_correct(self, result: Image.Image, init_image: Image.Image, init_mask: Image.Image, mask_blur_radius: int = 8) -> Image.Image:
|
||||
if init_image is None or init_mask is None:
|
||||
return result
|
||||
|
||||
# Get the original alpha channel of the mask if there is one.
|
||||
# Otherwise it is some other black/white image format ('1', 'L' or 'RGB')
|
||||
pil_init_mask = init_mask.getchannel('A') if init_mask.mode == 'RGBA' else init_mask.convert('L')
|
||||
pil_init_image = init_image.convert('RGBA') # Add an alpha channel if one doesn't exist
|
||||
|
||||
# Build an image with only visible pixels from source to use as reference for color-matching.
|
||||
init_rgb_pixels = np.asarray(init_image.convert('RGB'), dtype=np.uint8)
|
||||
init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8)
|
||||
init_mask_pixels = np.asarray(pil_init_mask, dtype=np.uint8)
|
||||
|
||||
# Get numpy version of result
|
||||
np_image = np.asarray(result, dtype=np.uint8)
|
||||
|
||||
# Mask and calculate mean and standard deviation
|
||||
mask_pixels = init_a_pixels * init_mask_pixels > 0
|
||||
np_init_rgb_pixels_masked = init_rgb_pixels[mask_pixels, :]
|
||||
np_image_masked = np_image[mask_pixels, :]
|
||||
|
||||
if np_init_rgb_pixels_masked.size > 0:
|
||||
init_means = np_init_rgb_pixels_masked.mean(axis=0)
|
||||
init_std = np_init_rgb_pixels_masked.std(axis=0)
|
||||
gen_means = np_image_masked.mean(axis=0)
|
||||
gen_std = np_image_masked.std(axis=0)
|
||||
|
||||
# Color correct
|
||||
np_matched_result = np_image.copy()
|
||||
np_matched_result[:,:,:] = (((np_matched_result[:,:,:].astype(np.float32) - gen_means[None,None,:]) / gen_std[None,None,:]) * init_std[None,None,:] + init_means[None,None,:]).clip(0, 255).astype(np.uint8)
|
||||
matched_result = Image.fromarray(np_matched_result, mode='RGB')
|
||||
else:
|
||||
matched_result = Image.fromarray(np_image, mode='RGB')
|
||||
|
||||
# Blur the mask out (into init image) by specified amount
|
||||
if mask_blur_radius > 0:
|
||||
nm = np.asarray(pil_init_mask, dtype=np.uint8)
|
||||
nmd = cv2.erode(nm, kernel=np.ones((3,3), dtype=np.uint8), iterations=int(mask_blur_radius / 2))
|
||||
pmd = Image.fromarray(nmd, mode='L')
|
||||
blurred_init_mask = pmd.filter(ImageFilter.BoxBlur(mask_blur_radius))
|
||||
else:
|
||||
blurred_init_mask = pil_init_mask
|
||||
|
||||
multiplied_blurred_init_mask = ImageChops.multiply(blurred_init_mask, self.pil_image.split()[-1])
|
||||
|
||||
# Paste original on color-corrected generation (using blurred mask)
|
||||
matched_result.paste(init_image, (0,0), mask = multiplied_blurred_init_mask)
|
||||
return matched_result
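# Note on the color correction above: it is per-channel moment matching. The
# mean and std of the source and generated pixels are computed over the masked
# region only, the generated image is then remapped with
#
#   corrected = (generated - gen_mean) / gen_std * init_std + init_mean
#
# (clipped to 0-255), and the original is pasted back through the blurred mask,
# so only the repainted area keeps the remapped colors.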
|
||||
|
||||
def sample_to_lowres_estimated_image(self,samples):
|
||||
# originally adapted from code by @erucipe and @keturn here:
|
||||
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
|
||||
|
||||
# these updated numbers for v1.5 are from @torridgristle
|
||||
v1_5_latent_rgb_factors = torch.tensor([
|
||||
# R G B
|
||||
[ 0.3444, 0.1385, 0.0670], # L1
|
||||
[ 0.1247, 0.4027, 0.1494], # L2
|
||||
[-0.3192, 0.2513, 0.2103], # L3
|
||||
[-0.1307, -0.1874, -0.7445] # L4
|
||||
], dtype=samples.dtype, device=samples.device)
|
||||
|
||||
latent_image = samples[0].permute(1, 2, 0) @ v1_5_latent_rgb_factors
|
||||
latents_ubyte = (((latent_image + 1) / 2)
|
||||
.clamp(0, 1) # change scale from -1..1 to 0..1
|
||||
.mul(0xFF) # to 0..255
|
||||
.byte()).cpu()
|
||||
|
||||
return Image.fromarray(latents_ubyte.numpy())
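# Shape sketch for the preview above (assuming SD-style 4-channel latents):
# samples is [batch, 4, h/8, w/8]; samples[0].permute(1, 2, 0) is
# [h/8, w/8, 4], and multiplying by the 4x3 factor matrix gives an
# [h/8, w/8, 3] RGB estimate at 1/8 resolution, without running the VAE.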
|
||||
|
||||
def generate_initial_noise(self, seed, width, height):
|
||||
initial_noise = None
|
||||
if self.variation_amount > 0 or len(self.with_variations) > 0:
|
||||
# use fixed initial noise plus random noise per iteration
|
||||
seed_everything(seed)
|
||||
initial_noise = self.get_noise(width,height)
|
||||
for v_seed, v_weight in self.with_variations:
|
||||
seed = v_seed
|
||||
seed_everything(seed)
|
||||
next_noise = self.get_noise(width,height)
|
||||
initial_noise = self.slerp(v_weight, initial_noise, next_noise)
|
||||
if self.variation_amount > 0:
|
||||
random.seed() # reset RNG to an actually random state, so we can get a random seed for variations
|
||||
seed = random.randrange(0,np.iinfo(np.uint32).max)
|
||||
return (seed, initial_noise)
|
||||
else:
|
||||
return (seed, None)
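# Illustrative sketch of the variation inputs consumed above (hypothetical
# values): with_variations is a list of (seed, weight) pairs that are slerped
# into the base noise in order, e.g.
#
#   gen.set_variation(seed=1234, variation_amount=0.2,
#                     with_variations=[(1111, 0.3), (2222, 0.1)])
#
# reproduces the named variations and then adds a fresh 0.2-weight variation
# on top of them for each iteration.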
|
||||
|
||||
# returns a tensor filled with random numbers from a normal distribution
|
||||
def get_noise(self,width,height):
|
||||
"""
|
||||
Returns a tensor filled with random numbers, either from a normal distribution
|
||||
(txt2img) or from the latent image (img2img, inpaint)
|
||||
"""
|
||||
raise NotImplementedError("get_noise() must be implemented in a descendent class")
|
||||
|
||||
def get_perlin_noise(self,width,height):
|
||||
fixdevice = 'cpu' if (self.model.device.type == 'mps') else self.model.device
|
||||
# limit noise to only the diffusion image channels, not the mask channels
|
||||
input_channels = min(self.latent_channels, 4)
|
||||
# round up to the nearest block of 8
|
||||
temp_width = int((width + 7) / 8) * 8
|
||||
temp_height = int((height + 7) / 8) * 8
|
||||
noise = torch.stack([
|
||||
rand_perlin_2d((temp_height, temp_width),
|
||||
(8, 8),
|
||||
device = self.model.device).to(fixdevice) for _ in range(input_channels)], dim=0).to(self.model.device)
|
||||
return noise[0:4, 0:height, 0:width]
|
||||
|
||||
def new_seed(self):
|
||||
self.seed = random.randrange(0, np.iinfo(np.uint32).max)
|
||||
return self.seed
|
||||
|
||||
def slerp(self, t, v0, v1, DOT_THRESHOLD=0.9995):
|
||||
'''
|
||||
Spherical linear interpolation
|
||||
Args:
|
||||
t (float/np.ndarray): Float value between 0.0 and 1.0
|
||||
v0 (np.ndarray): Starting vector
|
||||
v1 (np.ndarray): Final vector
|
||||
DOT_THRESHOLD (float): Threshold for considering the two vectors as
|
||||
colinear. Not recommended to alter this.
|
||||
Returns:
|
||||
v2 (np.ndarray): Interpolation vector between v0 and v1
|
||||
'''
|
||||
inputs_are_torch = False
|
||||
if not isinstance(v0, np.ndarray):
|
||||
inputs_are_torch = True
|
||||
v0 = v0.detach().cpu().numpy()
|
||||
if not isinstance(v1, np.ndarray):
|
||||
inputs_are_torch = True
|
||||
v1 = v1.detach().cpu().numpy()
|
||||
|
||||
dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
|
||||
if np.abs(dot) > DOT_THRESHOLD:
|
||||
v2 = (1 - t) * v0 + t * v1
|
||||
else:
|
||||
theta_0 = np.arccos(dot)
|
||||
sin_theta_0 = np.sin(theta_0)
|
||||
theta_t = theta_0 * t
|
||||
sin_theta_t = np.sin(theta_t)
|
||||
s0 = np.sin(theta_0 - theta_t) / sin_theta_0
|
||||
s1 = sin_theta_t / sin_theta_0
|
||||
v2 = s0 * v0 + s1 * v1
|
||||
|
||||
if inputs_are_torch:
|
||||
v2 = torch.from_numpy(v2).to(self.model.device)
|
||||
|
||||
return v2
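# Minimal sketch of how slerp() is used for variations (hypothetical tensors):
#
#   base   = self.get_noise(width, height)   # noise for the base seed
#   target = self.get_noise(width, height)   # noise for the variation seed
#   x_T    = self.slerp(0.2, base, target)   # 20% of the way toward target
#
# Unlike a plain lerp, this keeps the result on (approximately) the same
# hypersphere, so the blended tensor still looks like unit-variance Gaussian
# noise to the sampler.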
|
||||
|
||||
def safety_check(self,image:Image.Image):
|
||||
'''
|
||||
If the CompViz safety checker flags an NSFW image, we
|
||||
blur it out.
|
||||
'''
|
||||
import diffusers
|
||||
|
||||
checker = self.safety_checker['checker']
|
||||
extractor = self.safety_checker['extractor']
|
||||
features = extractor([image], return_tensors="pt")
|
||||
features.to(self.model.device)
|
||||
|
||||
# unfortunately checker requires the numpy version, so we have to convert back
|
||||
x_image = np.array(image).astype(np.float32) / 255.0
|
||||
x_image = x_image[None].transpose(0, 3, 1, 2)
|
||||
|
||||
diffusers.logging.set_verbosity_error()
|
||||
checked_image, has_nsfw_concept = checker(images=x_image, clip_input=features.pixel_values)
|
||||
if has_nsfw_concept[0]:
|
||||
print('** An image with potential non-safe content has been detected. A blurred image will be returned. **')
|
||||
return self.blur(image)
|
||||
else:
|
||||
return image
|
||||
|
||||
def blur(self,input):
|
||||
blurry = input.filter(filter=ImageFilter.GaussianBlur(radius=32))
|
||||
try:
|
||||
caution = self.get_caution_img()
|
||||
if caution:
|
||||
blurry.paste(caution,(0,0),caution)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
return blurry
|
||||
|
||||
def get_caution_img(self):
|
||||
path = None
|
||||
if self.caution_img:
|
||||
return self.caution_img
|
||||
path = Path(web_assets.__path__[0]) / CAUTION_IMG
|
||||
caution = Image.open(path)
|
||||
self.caution_img = caution.resize((caution.width // 2, caution.height //2))
|
||||
return self.caution_img
|
||||
|
||||
# this is a handy routine for debugging use. Given a generated sample,
|
||||
# convert it into a PNG image and store it at the indicated path
|
||||
def save_sample(self, sample, filepath):
|
||||
image = self.sample_to_image(sample)
|
||||
dirname = os.path.dirname(filepath) or '.'
|
||||
if not os.path.exists(dirname):
|
||||
print(f'** creating directory {dirname}')
|
||||
os.makedirs(dirname, exist_ok=True)
|
||||
image.save(filepath,'PNG')
|
||||
|
||||
|
||||
def torch_dtype(self)->torch.dtype:
|
||||
return torch.float16 if self.precision == 'float16' else torch.float32
|
||||
|
||||
# returns a tensor filled with random numbers from a normal distribution
# (note: this concrete definition overrides the abstract get_noise() stub declared earlier in the class)
|
||||
def get_noise(self,width,height):
|
||||
device = self.model.device
|
||||
# limit noise to only the diffusion image channels, not the mask channels
|
||||
input_channels = min(self.latent_channels, 4)
|
||||
if self.use_mps_noise or device.type == 'mps':
|
||||
x = torch.randn([1,
|
||||
input_channels,
|
||||
height // self.downsampling_factor,
|
||||
width // self.downsampling_factor],
|
||||
dtype=self.torch_dtype(),
|
||||
device='cpu').to(device)
|
||||
else:
|
||||
x = torch.randn([1,
|
||||
input_channels,
|
||||
height // self.downsampling_factor,
|
||||
width // self.downsampling_factor],
|
||||
dtype=self.torch_dtype(),
|
||||
device=device)
|
||||
if self.perlin > 0.0:
|
||||
perlin_noise = self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor)
|
||||
x = (1-self.perlin)*x + self.perlin*perlin_noise
|
||||
return x
|
invokeai/generator/diffusers_pipeline.py (new file, 764 lines)
@@ -0,0 +1,764 @@
from __future__ import annotations

import dataclasses
import inspect
import psutil
import secrets
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import List, Optional, Union, Callable, Type, TypeVar, Generic, Any

import PIL.Image
import einops
import torch
import torchvision.transforms as T
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.outputs import BaseOutput
from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from typing_extensions import ParamSpec

from ldm.invoke.globals import Globals
from invokeai.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent, PostprocessingSettings
from ldm.modules.textual_inversion_manager import TextualInversionManager
from ldm.invoke.devices import normalize_device, CPU_DEVICE
from ldm.invoke.offloading import LazilyLoadedModelGroup, FullyLoadedModelGroup, ModelGroup
from ..models.diffusion.cross_attention_map_saving import AttentionMapSaver
from compel import EmbeddingsProvider

@dataclass
|
||||
class PipelineIntermediateState:
|
||||
run_id: str
|
||||
step: int
|
||||
timestep: int
|
||||
latents: torch.Tensor
|
||||
predicted_original: Optional[torch.Tensor] = None
|
||||
attention_map_saver: Optional[AttentionMapSaver] = None
|
||||
|
||||
|
||||
# copied from configs/stable-diffusion/v1-inference.yaml
|
||||
_default_personalization_config_params = dict(
|
||||
placeholder_strings=["*"],
|
||||
initializer_words=["sculpture"],
|
||||
per_image_tokens=False,
|
||||
num_vectors_per_token=1,
|
||||
progressive_words=False
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AddsMaskLatents:
|
||||
"""Add the channels required for inpainting model input.
|
||||
|
||||
The inpainting model takes the normal latent channels as input, _plus_ a one-channel mask
|
||||
and the latent encoding of the base image.
|
||||
|
||||
This class assumes the same mask and base image should apply to all items in the batch.
|
||||
"""
|
||||
forward: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
|
||||
mask: torch.Tensor
|
||||
initial_image_latents: torch.Tensor
|
||||
|
||||
def __call__(self, latents: torch.Tensor, t: torch.Tensor, text_embeddings: torch.Tensor) -> torch.Tensor:
|
||||
model_input = self.add_mask_channels(latents)
|
||||
return self.forward(model_input, t, text_embeddings)
|
||||
|
||||
def add_mask_channels(self, latents):
|
||||
batch_size = latents.size(0)
|
||||
# duplicate mask and latents for each batch
|
||||
mask = einops.repeat(self.mask, 'b c h w -> (repeat b) c h w', repeat=batch_size)
|
||||
image_latents = einops.repeat(self.initial_image_latents, 'b c h w -> (repeat b) c h w', repeat=batch_size)
|
||||
# add mask and image as additional channels
|
||||
model_input, _ = einops.pack([latents, mask, image_latents], 'b * h w')
|
||||
return model_input
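# Channel bookkeeping for the packing above (SD inpainting checkpoints): with
# 4-channel latents, a 1-channel mask and 4-channel image latents,
# einops.pack([latents, mask, image_latents], 'b * h w') concatenates along
# the channel axis into a [batch, 9, h, w] tensor, which is exactly the
# unet.conv_in.in_channels == 9 condition tested by is_inpainting_model() below.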
|
||||
|
||||
|
||||
def are_like_tensors(a: torch.Tensor, b: object) -> bool:
|
||||
return (
|
||||
isinstance(b, torch.Tensor)
|
||||
and (a.size() == b.size())
|
||||
)
|
||||
|
||||
@dataclass
|
||||
class AddsMaskGuidance:
|
||||
mask: torch.FloatTensor
|
||||
mask_latents: torch.FloatTensor
|
||||
scheduler: SchedulerMixin
|
||||
noise: torch.Tensor
|
||||
_debug: Optional[Callable] = None
|
||||
|
||||
def __call__(self, step_output: BaseOutput | SchedulerOutput, t: torch.Tensor, conditioning) -> BaseOutput:
|
||||
output_class = step_output.__class__ # We'll create a new one with masked data.
|
||||
|
||||
# The problem with taking SchedulerOutput instead of the model output is that we're less certain what's in it.
|
||||
# It's reasonable to assume the first thing is prev_sample, but then does it have other things
|
||||
# like pred_original_sample? Should we apply the mask to them too?
|
||||
# But what if there's just some other random field?
|
||||
prev_sample = step_output[0]
|
||||
# Mask anything that has the same shape as prev_sample, return others as-is.
|
||||
return output_class(
|
||||
{k: (self.apply_mask(v, self._t_for_field(k, t))
|
||||
if are_like_tensors(prev_sample, v) else v)
|
||||
for k, v in step_output.items()}
|
||||
)
|
||||
|
||||
def _t_for_field(self, field_name:str, t):
|
||||
if field_name == "pred_original_sample":
|
||||
return torch.zeros_like(t, dtype=t.dtype) # it represents t=0
|
||||
return t
|
||||
|
||||
def apply_mask(self, latents: torch.Tensor, t) -> torch.Tensor:
|
||||
batch_size = latents.size(0)
|
||||
mask = einops.repeat(self.mask, 'b c h w -> (repeat b) c h w', repeat=batch_size)
|
||||
if t.dim() == 0:
|
||||
# some schedulers expect t to be one-dimensional.
|
||||
# TODO: file diffusers bug about inconsistency?
|
||||
t = einops.repeat(t, '-> batch', batch=batch_size)
|
||||
# Noise shouldn't be re-randomized between steps here. The multistep schedulers
|
||||
# get very confused about what is happening from step to step when we do that.
|
||||
mask_latents = self.scheduler.add_noise(self.mask_latents, self.noise, t)
|
||||
# TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?
|
||||
# mask_latents = self.scheduler.scale_model_input(mask_latents, t)
|
||||
mask_latents = einops.repeat(mask_latents, 'b c h w -> (repeat b) c h w', repeat=batch_size)
|
||||
masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype))
|
||||
if self._debug:
|
||||
self._debug(masked_input, f"t={t} lerped")
|
||||
return masked_input
|
||||
|
||||
|
||||
def trim_to_multiple_of(*args, multiple_of=8):
|
||||
return tuple((x - x % multiple_of) for x in args)
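# Example: trim_to_multiple_of(513, 769) == (512, 768); dimensions are rounded
# down to the nearest multiple of 8 so the VAE's 8x downsampling divides evenly.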
|
||||
|
||||
|
||||
def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool=True, multiple_of=8) -> torch.FloatTensor:
|
||||
"""
|
||||
|
||||
:param image: input image
|
||||
:param normalize: scale the range to [-1, 1] instead of [0, 1]
|
||||
:param multiple_of: resize the input so both dimensions are a multiple of this
|
||||
"""
|
||||
w, h = trim_to_multiple_of(*image.size)
|
||||
transformation = T.Compose([
|
||||
T.Resize((h, w), T.InterpolationMode.LANCZOS),
|
||||
T.ToTensor(),
|
||||
])
|
||||
tensor = transformation(image)
|
||||
if normalize:
|
||||
tensor = tensor * 2.0 - 1.0
|
||||
return tensor
|
||||
|
||||
|
||||
def is_inpainting_model(unet: UNet2DConditionModel):
|
||||
return unet.conv_in.in_channels == 9
|
||||
|
||||
CallbackType = TypeVar('CallbackType')
|
||||
ReturnType = TypeVar('ReturnType')
|
||||
ParamType = ParamSpec('ParamType')
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
|
||||
"""Convert a generator to a function with a callback and a return value."""
|
||||
|
||||
generator_method: Callable[ParamType, ReturnType]
|
||||
callback_arg_type: Type[CallbackType]
|
||||
|
||||
def __call__(self, *args: ParamType.args,
|
||||
callback:Callable[[CallbackType], Any]=None,
|
||||
**kwargs: ParamType.kwargs) -> ReturnType:
|
||||
result = None
|
||||
for result in self.generator_method(*args, **kwargs):
|
||||
if callback is not None and isinstance(result, self.callback_arg_type):
|
||||
callback(result)
|
||||
if result is None:
|
||||
raise AssertionError("why was that an empty generator?")
|
||||
return result
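# Usage sketch (this mirrors how latents_from_embeddings wires it up further
# below; the callback name is hypothetical):
#
#   infer = GeneratorToCallbackinator(self.generate_latents_from_embeddings,
#                                     PipelineIntermediateState)
#   result = infer(latents, timesteps, conditioning_data,
#                  noise=noise, callback=report_progress)
#
# Every PipelineIntermediateState yielded by the generator is forwarded to the
# callback, and the last value it yields is returned.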
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ConditioningData:
|
||||
unconditioned_embeddings: torch.Tensor
|
||||
text_embeddings: torch.Tensor
|
||||
guidance_scale: float
|
||||
"""
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
|
||||
Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
|
||||
images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
|
||||
"""
|
||||
extra: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo] = None
|
||||
scheduler_args: dict[str, Any] = field(default_factory=dict)
|
||||
"""
|
||||
Additional arguments to pass to invokeai_diffuser.do_latent_postprocessing().
|
||||
"""
|
||||
postprocessing_settings: Optional[PostprocessingSettings] = None
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self.text_embeddings.dtype
|
||||
|
||||
def add_scheduler_args_if_applicable(self, scheduler, **kwargs):
|
||||
scheduler_args = dict(self.scheduler_args)
|
||||
step_method = inspect.signature(scheduler.step)
|
||||
for name, value in kwargs.items():
|
||||
try:
|
||||
step_method.bind_partial(**{name: value})
|
||||
except TypeError:
|
||||
# FIXME: don't silently discard arguments
|
||||
pass # debug("%s does not accept argument named %r", scheduler, name)
|
||||
else:
|
||||
scheduler_args[name] = value
|
||||
return dataclasses.replace(self, scheduler_args=scheduler_args)
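# Example (hypothetical values): only schedulers whose step() signature accepts
# a keyword receive it, so a DDIM-style scheduler gets eta while others
# silently ignore it:
#
#   conditioning_data = conditioning_data.add_scheduler_args_if_applicable(
#       pipeline.scheduler, eta=0.0)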
|
||||
|
||||
@dataclass
|
||||
class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput):
|
||||
r"""
|
||||
Output class for InvokeAI's Stable Diffusion pipeline.
|
||||
|
||||
Args:
|
||||
attention_map_saver (`AttentionMapSaver`): Object containing attention maps that can be displayed to the user
|
||||
after generation completes. Optional.
|
||||
"""
|
||||
attention_map_saver: Optional[AttentionMapSaver]
|
||||
|
||||
|
||||
class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Implementation note: This class started as a refactored copy of diffusers.StableDiffusionPipeline.
|
||||
Hopefully future versions of diffusers provide access to more of these functions so that we don't
|
||||
need to duplicate them here: https://github.com/huggingface/diffusers/issues/551#issuecomment-1281508384
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`CLIPTextModel`]):
|
||||
Frozen text-encoder. Stable Diffusion uses the text portion of
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
||||
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
_model_group: ModelGroup
|
||||
|
||||
ID_LENGTH = 8
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
safety_checker: Optional[StableDiffusionSafetyChecker],
|
||||
feature_extractor: Optional[CLIPFeatureExtractor],
|
||||
requires_safety_checker: bool = False,
|
||||
precision: str = 'float32',
|
||||
):
|
||||
super().__init__(vae, text_encoder, tokenizer, unet, scheduler,
|
||||
safety_checker, feature_extractor, requires_safety_checker)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward, is_running_diffusers=True)
|
||||
use_full_precision = (precision == 'float32' or precision == 'autocast')
|
||||
self.textual_inversion_manager = TextualInversionManager(tokenizer=self.tokenizer,
|
||||
text_encoder=self.text_encoder,
|
||||
full_precision=use_full_precision)
|
||||
# InvokeAI's interface for text embeddings and whatnot
|
||||
self.embeddings_provider = EmbeddingsProvider(
|
||||
tokenizer=self.tokenizer,
|
||||
text_encoder=self.text_encoder,
|
||||
textual_inversion_manager=self.textual_inversion_manager
|
||||
)
|
||||
|
||||
self._model_group = FullyLoadedModelGroup(self.unet.device)
|
||||
self._model_group.install(*self._submodels)
|
||||
|
||||
|
||||
def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
|
||||
"""
|
||||
if xformers is available, use it, otherwise use sliced attention.
|
||||
"""
|
||||
if torch.cuda.is_available() and is_xformers_available() and not Globals.disable_xformers:
|
||||
self.enable_xformers_memory_efficient_attention()
|
||||
else:
|
||||
if torch.backends.mps.is_available():
|
||||
# until pytorch #91617 is fixed, slicing is borked on MPS
|
||||
# https://github.com/pytorch/pytorch/issues/91617
|
||||
# fix is in https://github.com/kulinseth/pytorch/pull/222 but no idea when it will get merged to pytorch mainline.
|
||||
pass
|
||||
else:
|
||||
if self.device.type == 'cpu' or self.device.type == 'mps':
|
||||
mem_free = psutil.virtual_memory().free
|
||||
elif self.device.type == 'cuda':
|
||||
mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.device))
|
||||
else:
|
||||
raise ValueError(f"unrecognized device {self.device}")
|
||||
# input tensor of [1, 4, h/8, w/8]
|
||||
# output tensor of [16, (h/8 * w/8), (h/8 * w/8)]
|
||||
bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4
|
||||
max_size_required_for_baddbmm = \
|
||||
16 * \
|
||||
latents.size(dim=2) * latents.size(dim=3) * latents.size(dim=2) * latents.size(dim=3) * \
|
||||
bytes_per_element_needed_for_baddbmm_duplication
|
||||
if max_size_required_for_baddbmm > (mem_free * 3.0 / 4.0): # 3.3 / 4.0 is from old Invoke code
|
||||
self.enable_attention_slicing(slice_size='max')
|
||||
else:
|
||||
self.disable_attention_slicing()
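# Worked example of the estimate above (assuming a 512x512 image and fp16
# latents): latents are [1, 4, 64, 64], element_size() == 2, so 6 bytes per
# element and max_size_required_for_baddbmm = 16 * 64*64 * 64*64 * 6, about
# 1.6 GB; attention slicing is therefore enabled unless roughly 2.1 GB
# (4/3 of that) is free on the execution device.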
|
||||
|
||||
|
||||
def enable_offload_submodels(self, device: torch.device):
|
||||
"""
|
||||
Offload each submodel when it's not in use.
|
||||
|
||||
Useful for low-vRAM situations where the size of the model in memory is a big chunk of
|
||||
the total available resource, and you want to free up as much for inference as possible.
|
||||
|
||||
This requires more moving parts and may add some delay as the U-Net is swapped out for the
|
||||
VAE and vice-versa.
|
||||
"""
|
||||
models = self._submodels
|
||||
if self._model_group is not None:
|
||||
self._model_group.uninstall(*models)
|
||||
group = LazilyLoadedModelGroup(device)
|
||||
group.install(*models)
|
||||
self._model_group = group
|
||||
|
||||
def disable_offload_submodels(self):
|
||||
"""
|
||||
Leave all submodels loaded.
|
||||
|
||||
Appropriate for cases where the size of the model in memory is small compared to the memory
|
||||
required for inference. Avoids the delay and complexity of shuffling the submodels to and
|
||||
from the GPU.
|
||||
"""
|
||||
models = self._submodels
|
||||
if self._model_group is not None:
|
||||
self._model_group.uninstall(*models)
|
||||
group = FullyLoadedModelGroup(self._model_group.execution_device)
|
||||
group.install(*models)
|
||||
self._model_group = group
|
||||
|
||||
def offload_all(self):
|
||||
"""Offload all this pipeline's models to CPU."""
|
||||
self._model_group.offload_current()
|
||||
|
||||
def ready(self):
|
||||
"""
|
||||
Ready this pipeline's models.
|
||||
|
||||
i.e. pre-load them to the GPU if appropriate.
|
||||
"""
|
||||
self._model_group.ready()
|
||||
|
||||
def to(self, torch_device: Optional[Union[str, torch.device]] = None):
|
||||
# overridden method; types match the superclass.
|
||||
if torch_device is None:
|
||||
return self
|
||||
self._model_group.set_device(torch.device(torch_device))
|
||||
self._model_group.ready()
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return self._model_group.execution_device
|
||||
|
||||
@property
|
||||
def _submodels(self) -> Sequence[torch.nn.Module]:
|
||||
module_names, _, _ = self.extract_init_dict(dict(self.config))
|
||||
values = [getattr(self, name) for name in module_names.keys()]
|
||||
return [m for m in values if isinstance(m, torch.nn.Module)]
|
||||
|
||||
def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
|
||||
conditioning_data: ConditioningData,
|
||||
*,
|
||||
noise: torch.Tensor,
|
||||
callback: Callable[[PipelineIntermediateState], None]=None,
|
||||
run_id=None) -> InvokeAIStableDiffusionPipelineOutput:
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
:param conditioning_data:
|
||||
:param latents: Pre-generated un-noised latents, to be used as inputs for
|
||||
image generation. Can be used to tweak the same generation with different prompts.
|
||||
:param num_inference_steps: The number of denoising steps. More denoising steps usually lead to a higher quality
|
||||
image at the expense of slower inference.
|
||||
:param noise: Noise to add to the latents, sampled from a Gaussian distribution.
|
||||
:param callback:
|
||||
:param run_id:
|
||||
"""
|
||||
result_latents, result_attention_map_saver = self.latents_from_embeddings(
|
||||
latents, num_inference_steps,
|
||||
conditioning_data,
|
||||
noise=noise,
|
||||
run_id=run_id,
|
||||
callback=callback)
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
with torch.inference_mode():
|
||||
image = self.decode_latents(result_latents)
|
||||
output = InvokeAIStableDiffusionPipelineOutput(images=image, nsfw_content_detected=[], attention_map_saver=result_attention_map_saver)
|
||||
return self.check_for_safety(output, dtype=conditioning_data.dtype)
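# End-to-end sketch of driving this entry point for txt2img (hypothetical
# embedding variables; conditioning must already be embedded, e.g. via
# get_learned_conditioning further below):
#
#   latents = torch.zeros(1, 4, 64, 64,  # 4 latent channels, 64x64 for 512x512
#                         device=pipeline.device, dtype=pipeline.unet.dtype)
#   noise = torch.randn_like(latents)
#   cond = ConditioningData(uc_embeddings, c_embeddings, guidance_scale=7.5)
#   output = pipeline.image_from_embeddings(latents, 30, cond, noise=noise)
#   output.images[0]  # decoded image array from decode_latents()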
|
||||
|
||||
def latents_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
|
||||
conditioning_data: ConditioningData,
|
||||
*,
|
||||
noise: torch.Tensor,
|
||||
timesteps=None,
|
||||
additional_guidance: List[Callable] = None, run_id=None,
|
||||
callback: Callable[[PipelineIntermediateState], None] = None
|
||||
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
|
||||
if timesteps is None:
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=self._model_group.device_for(self.unet))
|
||||
timesteps = self.scheduler.timesteps
|
||||
infer_latents_from_embeddings = GeneratorToCallbackinator(self.generate_latents_from_embeddings, PipelineIntermediateState)
|
||||
result: PipelineIntermediateState = infer_latents_from_embeddings(
|
||||
latents, timesteps, conditioning_data,
|
||||
noise=noise,
|
||||
additional_guidance=additional_guidance,
|
||||
run_id=run_id,
|
||||
callback=callback)
|
||||
return result.latents, result.attention_map_saver
|
||||
|
||||
def generate_latents_from_embeddings(self, latents: torch.Tensor, timesteps,
|
||||
conditioning_data: ConditioningData,
|
||||
*,
|
||||
noise: torch.Tensor,
|
||||
run_id: str = None,
|
||||
additional_guidance: List[Callable] = None):
|
||||
self._adjust_memory_efficient_attention(latents)
|
||||
if run_id is None:
|
||||
run_id = secrets.token_urlsafe(self.ID_LENGTH)
|
||||
if additional_guidance is None:
|
||||
additional_guidance = []
|
||||
extra_conditioning_info = conditioning_data.extra
|
||||
with self.invokeai_diffuser.custom_attention_context(extra_conditioning_info=extra_conditioning_info,
|
||||
step_count=len(self.scheduler.timesteps)
|
||||
):
|
||||
|
||||
yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps,
|
||||
latents=latents)
|
||||
|
||||
batch_size = latents.shape[0]
|
||||
batched_t = torch.full((batch_size,), timesteps[0],
|
||||
dtype=timesteps.dtype, device=self._model_group.device_for(self.unet))
|
||||
latents = self.scheduler.add_noise(latents, noise, batched_t)
|
||||
|
||||
attention_map_saver: Optional[AttentionMapSaver] = None
|
||||
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
batched_t.fill_(t)
|
||||
step_output = self.step(batched_t, latents, conditioning_data,
|
||||
step_index=i,
|
||||
total_step_count=len(timesteps),
|
||||
additional_guidance=additional_guidance)
|
||||
latents = step_output.prev_sample
|
||||
|
||||
latents = self.invokeai_diffuser.do_latent_postprocessing(
|
||||
postprocessing_settings=conditioning_data.postprocessing_settings,
|
||||
latents=latents,
|
||||
sigma=batched_t,
|
||||
step_index=i,
|
||||
total_step_count=len(timesteps)
|
||||
)
|
||||
|
||||
predicted_original = getattr(step_output, 'pred_original_sample', None)
|
||||
|
||||
# TODO resuscitate attention map saving
|
||||
#if i == len(timesteps)-1 and extra_conditioning_info is not None:
|
||||
# eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
|
||||
# attention_map_token_ids = range(1, eos_token_index)
|
||||
# attention_map_saver = AttentionMapSaver(token_ids=attention_map_token_ids, latents_shape=latents.shape[-2:])
|
||||
# self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)
|
||||
|
||||
yield PipelineIntermediateState(run_id=run_id, step=i, timestep=int(t), latents=latents,
|
||||
predicted_original=predicted_original, attention_map_saver=attention_map_saver)
|
||||
|
||||
return latents, attention_map_saver
|
||||
|
||||
@torch.inference_mode()
|
||||
def step(self, t: torch.Tensor, latents: torch.Tensor,
|
||||
conditioning_data: ConditioningData,
|
||||
step_index:int, total_step_count:int,
|
||||
additional_guidance: List[Callable] = None):
|
||||
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
|
||||
timestep = t[0]
|
||||
|
||||
if additional_guidance is None:
|
||||
additional_guidance = []
|
||||
|
||||
# TODO: should this scaling happen here or inside self._unet_forward?
|
||||
# i.e. before or after passing it to InvokeAIDiffuserComponent
|
||||
latent_model_input = self.scheduler.scale_model_input(latents, timestep)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.invokeai_diffuser.do_diffusion_step(
|
||||
latent_model_input, t,
|
||||
conditioning_data.unconditioned_embeddings, conditioning_data.text_embeddings,
|
||||
conditioning_data.guidance_scale,
|
||||
step_index=step_index,
|
||||
total_step_count=total_step_count,
|
||||
)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
step_output = self.scheduler.step(noise_pred, timestep, latents,
|
||||
**conditioning_data.scheduler_args)
|
||||
|
||||
# TODO: this additional_guidance extension point feels redundant with InvokeAIDiffusionComponent.
|
||||
# But the way things are now, scheduler runs _after_ that, so there was
|
||||
# no way to use it to apply an operation that happens after the last scheduler.step.
|
||||
for guidance in additional_guidance:
|
||||
step_output = guidance(step_output, timestep, conditioning_data)
|
||||
|
||||
return step_output
|
||||
|
||||
def _unet_forward(self, latents, t, text_embeddings, cross_attention_kwargs: Optional[dict[str,Any]] = None):
|
||||
"""predict the noise residual"""
|
||||
if is_inpainting_model(self.unet) and latents.size(1) == 4:
|
||||
# Pad out normal non-inpainting inputs for an inpainting model.
|
||||
# FIXME: There are too many layers of functions and we have too many different ways of
|
||||
# overriding things! This should get handled in a way more consistent with the other
|
||||
# use of AddsMaskLatents.
|
||||
latents = AddsMaskLatents(
|
||||
self._unet_forward,
|
||||
mask=torch.ones_like(latents[:1, :1], device=latents.device, dtype=latents.dtype),
|
||||
initial_image_latents=torch.zeros_like(latents[:1], device=latents.device, dtype=latents.dtype)
|
||||
).add_mask_channels(latents)
|
||||
|
||||
# First three args should be positional, not keywords, so torch hooks can see them.
|
||||
return self.unet(latents, t, text_embeddings,
|
||||
cross_attention_kwargs=cross_attention_kwargs).sample
|
||||
|
||||
def img2img_from_embeddings(self,
|
||||
init_image: Union[torch.FloatTensor, PIL.Image.Image],
|
||||
strength: float,
|
||||
num_inference_steps: int,
|
||||
conditioning_data: ConditioningData,
|
||||
*, callback: Callable[[PipelineIntermediateState], None] = None,
|
||||
run_id=None,
|
||||
noise_func=None
|
||||
) -> InvokeAIStableDiffusionPipelineOutput:
|
||||
if isinstance(init_image, PIL.Image.Image):
|
||||
init_image = image_resized_to_grid_as_tensor(init_image.convert('RGB'))
|
||||
|
||||
if init_image.dim() == 3:
|
||||
init_image = einops.rearrange(init_image, 'c h w -> 1 c h w')
|
||||
|
||||
# 6. Prepare latent variables
|
||||
initial_latents = self.non_noised_latents_from_image(
|
||||
init_image, device=self._model_group.device_for(self.unet),
|
||||
dtype=self.unet.dtype)
|
||||
noise = noise_func(initial_latents)
|
||||
|
||||
return self.img2img_from_latents_and_embeddings(initial_latents, num_inference_steps,
|
||||
conditioning_data,
|
||||
strength,
|
||||
noise, run_id, callback)
|
||||
|
||||
def img2img_from_latents_and_embeddings(self, initial_latents, num_inference_steps,
|
||||
conditioning_data: ConditioningData,
|
||||
strength,
|
||||
noise: torch.Tensor, run_id=None, callback=None
|
||||
) -> InvokeAIStableDiffusionPipelineOutput:
|
||||
timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength,
|
||||
device=self._model_group.device_for(self.unet))
|
||||
result_latents, result_attention_maps = self.latents_from_embeddings(
|
||||
initial_latents, num_inference_steps, conditioning_data,
|
||||
timesteps=timesteps,
|
||||
noise=noise,
|
||||
run_id=run_id,
|
||||
callback=callback)
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
with torch.inference_mode():
|
||||
image = self.decode_latents(result_latents)
|
||||
output = InvokeAIStableDiffusionPipelineOutput(images=image, nsfw_content_detected=[], attention_map_saver=result_attention_maps)
|
||||
return self.check_for_safety(output, dtype=conditioning_data.dtype)
|
||||
|
||||
def get_img2img_timesteps(self, num_inference_steps: int, strength: float, device) -> (torch.Tensor, int):
|
||||
img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
|
||||
assert img2img_pipeline.scheduler is self.scheduler
|
||||
img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps, adjusted_steps = img2img_pipeline.get_timesteps(num_inference_steps, strength, device=device)
|
||||
# Workaround for low strength resulting in zero timesteps.
|
||||
# TODO: submit upstream fix for zero-step img2img
|
||||
if timesteps.numel() == 0:
|
||||
timesteps = self.scheduler.timesteps[-1:]
|
||||
adjusted_steps = timesteps.numel()
|
||||
return timesteps, adjusted_steps
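# Quick arithmetic for the strength-to-steps mapping delegated to the upstream
# img2img pipeline above: with num_inference_steps=50 and strength=0.75,
# roughly the last 37 of the 50 scheduled timesteps are kept, so lower strength
# preserves more of the init image; a strength near 0 can yield zero timesteps,
# which is what the single-timestep fallback above guards against.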
|
||||
|
||||
def inpaint_from_embeddings(
|
||||
self,
|
||||
init_image: torch.FloatTensor,
|
||||
mask: torch.FloatTensor,
|
||||
strength: float,
|
||||
num_inference_steps: int,
|
||||
conditioning_data: ConditioningData,
|
||||
*, callback: Callable[[PipelineIntermediateState], None] = None,
|
||||
run_id=None,
|
||||
noise_func=None,
|
||||
) -> InvokeAIStableDiffusionPipelineOutput:
|
||||
device = self._model_group.device_for(self.unet)
|
||||
latents_dtype = self.unet.dtype
|
||||
|
||||
if isinstance(init_image, PIL.Image.Image):
|
||||
init_image = image_resized_to_grid_as_tensor(init_image.convert('RGB'))
|
||||
|
||||
init_image = init_image.to(device=device, dtype=latents_dtype)
|
||||
mask = mask.to(device=device, dtype=latents_dtype)
|
||||
|
||||
if init_image.dim() == 3:
|
||||
init_image = init_image.unsqueeze(0)
|
||||
|
||||
timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength, device=device)
|
||||
|
||||
# 6. Prepare latent variables
|
||||
# can't quite use upstream StableDiffusionImg2ImgPipeline.prepare_latents
|
||||
# because we have our own noise function
|
||||
init_image_latents = self.non_noised_latents_from_image(init_image, device=device, dtype=latents_dtype)
|
||||
noise = noise_func(init_image_latents)
|
||||
|
||||
if mask.dim() == 3:
|
||||
mask = mask.unsqueeze(0)
|
||||
latent_mask = tv_resize(mask, init_image_latents.shape[-2:], T.InterpolationMode.BILINEAR) \
|
||||
.to(device=device, dtype=latents_dtype)
|
||||
|
||||
guidance: List[Callable] = []
|
||||
|
||||
if is_inpainting_model(self.unet):
|
||||
# You'd think the inpainting model wouldn't be paying attention to the area it is going to repaint
|
||||
# (that's why there's a mask!) but it seems to really want that blanked out.
|
||||
masked_init_image = init_image * torch.where(mask < 0.5, 1, 0)
|
||||
masked_latents = self.non_noised_latents_from_image(masked_init_image, device=device, dtype=latents_dtype)
|
||||
|
||||
# TODO: we should probably pass this in so we don't have to try/finally around setting it.
|
||||
self.invokeai_diffuser.model_forward_callback = \
|
||||
AddsMaskLatents(self._unet_forward, latent_mask, masked_latents)
|
||||
else:
|
||||
guidance.append(AddsMaskGuidance(latent_mask, init_image_latents, self.scheduler, noise))
|
||||
|
||||
try:
|
||||
result_latents, result_attention_maps = self.latents_from_embeddings(
|
||||
init_image_latents, num_inference_steps,
|
||||
conditioning_data, noise=noise, timesteps=timesteps,
|
||||
additional_guidance=guidance,
|
||||
run_id=run_id, callback=callback)
|
||||
finally:
|
||||
self.invokeai_diffuser.model_forward_callback = self._unet_forward
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
with torch.inference_mode():
|
||||
image = self.decode_latents(result_latents)
|
||||
output = InvokeAIStableDiffusionPipelineOutput(images=image, nsfw_content_detected=[], attention_map_saver=result_attention_maps)
|
||||
return self.check_for_safety(output, dtype=conditioning_data.dtype)
|
||||
|
||||
def non_noised_latents_from_image(self, init_image, *, device: torch.device, dtype):
|
||||
init_image = init_image.to(device=device, dtype=dtype)
|
||||
with torch.inference_mode():
|
||||
if device.type == 'mps':
|
||||
# workaround for torch MPS bug that has been fixed in https://github.com/kulinseth/pytorch/pull/222
|
||||
# TODO remove this workaround once kulinseth#222 is merged to pytorch mainline
|
||||
self.vae.to(CPU_DEVICE)
|
||||
init_image = init_image.to(CPU_DEVICE)
|
||||
else:
|
||||
self._model_group.load(self.vae)
|
||||
init_latent_dist = self.vae.encode(init_image).latent_dist
|
||||
init_latents = init_latent_dist.sample().to(dtype=dtype) # FIXME: uses torch.randn. make reproducible!
|
||||
if device.type == 'mps':
|
||||
self.vae.to(device)
|
||||
init_latents = init_latents.to(device)
|
||||
|
||||
init_latents = 0.18215 * init_latents  # 0.18215 is the standard Stable Diffusion v1 VAE latent scaling factor
|
||||
return init_latents
|
||||
|
||||
def check_for_safety(self, output, dtype):
|
||||
with torch.inference_mode():
|
||||
screened_images, has_nsfw_concept = self.run_safety_checker(output.images, dtype=dtype)
|
||||
screened_attention_map_saver = None
|
||||
if has_nsfw_concept is None or not has_nsfw_concept:
|
||||
screened_attention_map_saver = output.attention_map_saver
|
||||
return InvokeAIStableDiffusionPipelineOutput(screened_images,
|
||||
has_nsfw_concept,
|
||||
# block the attention maps if NSFW content is detected
|
||||
attention_map_saver=screened_attention_map_saver)
|
||||
|
||||
def run_safety_checker(self, image, device=None, dtype=None):
|
||||
# overriding to use the model group for device info instead of requiring the caller to know.
|
||||
if self.safety_checker is not None:
|
||||
device = self._model_group.device_for(self.safety_checker)
|
||||
return super().run_safety_checker(image, device, dtype)
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_learned_conditioning(self, c: List[List[str]], *, return_tokens=True, fragment_weights=None):
|
||||
"""
|
||||
Compatibility function for invokeai.models.diffusion.ddpm.LatentDiffusion.
|
||||
"""
|
||||
return self.embeddings_provider.get_embeddings_for_weighted_prompt_fragments(
|
||||
text_batch=c,
|
||||
fragment_weights_batch=fragment_weights,
|
||||
should_return_tokens=return_tokens,
|
||||
device=self._model_group.device_for(self.unet))
|
||||
|
||||
@property
|
||||
def cond_stage_model(self):
|
||||
return self.embeddings_provider
|
||||
|
||||
@torch.inference_mode()
|
||||
def _tokenize(self, prompt: Union[str, List[str]]):
|
||||
return self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
@property
|
||||
def channels(self) -> int:
|
||||
"""Compatible with DiffusionWrapper"""
|
||||
return self.unet.in_channels
|
||||
|
||||
def decode_latents(self, latents):
|
||||
# Explicit call to get the vae loaded, since `decode` isn't the forward method.
|
||||
self._model_group.load(self.vae)
|
||||
return super().decode_latents(latents)
|
||||
|
||||
def debug_latents(self, latents, msg):
|
||||
with torch.inference_mode():
|
||||
from ldm.util import debug_image
|
||||
decoded = self.numpy_to_pil(self.decode_latents(latents))
|
||||
for i, img in enumerate(decoded):
|
||||
debug_image(img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True)
|
invokeai/generator/embiggen.py (new file, 501 lines; listing truncated below)
@@ -0,0 +1,501 @@
'''
invokeai.backend.generator.embiggen descends from ldm.invoke.generator
and generates with invokeai.backend.generator.img2img
'''

import numpy as np
import torch
from PIL import Image
from tqdm import trange

from .base import Generator
from .img2img import Img2Img

class Embiggen(Generator):
|
||||
def __init__(self, model, precision):
|
||||
super().__init__(model, precision)
|
||||
self.init_latent = None
|
||||
|
||||
# Replace generate because Embiggen doesn't need/use most of what it does normally
|
||||
def generate(self,prompt,iterations=1,seed=None,
|
||||
image_callback=None, step_callback=None,
|
||||
**kwargs):
|
||||
|
||||
make_image = self.get_make_image(
|
||||
prompt,
|
||||
step_callback = step_callback,
|
||||
**kwargs
|
||||
)
|
||||
results = []
|
||||
seed = seed if seed else self.new_seed()
|
||||
|
||||
# Noise will be generated by the Img2Img generator when called
|
||||
for _ in trange(iterations, desc='Generating'):
|
||||
# make_image will call Img2Img which will do the equivalent of get_noise itself
|
||||
image = make_image()
|
||||
results.append([image, seed])
|
||||
if image_callback is not None:
|
||||
image_callback(image, seed, prompt_in=prompt)
|
||||
seed = self.new_seed()
|
||||
return results
|
||||
|
||||
@torch.no_grad()
|
||||
def get_make_image(
|
||||
self,
|
||||
prompt,
|
||||
sampler,
|
||||
steps,
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
conditioning,
|
||||
init_img,
|
||||
strength,
|
||||
width,
|
||||
height,
|
||||
embiggen,
|
||||
embiggen_tiles,
|
||||
step_callback=None,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and multi-stage twice-baked potato layering over the img2img on the initial image
|
||||
Return value depends on the seed at the time you call it
|
||||
"""
|
||||
assert not sampler.uses_inpainting_model(), "--embiggen is not supported by inpainting models"
|
||||
|
||||
# Construct embiggen arg array, and sanity check arguments
|
||||
if embiggen is None: # embiggen can also be called with just embiggen_tiles
|
||||
embiggen = [1.0] # If not specified, assume no scaling
|
||||
elif embiggen[0] < 0:
|
||||
embiggen[0] = 1.0
|
||||
print(
|
||||
'>> Embiggen scaling factor cannot be negative, fell back to the default of 1.0 !')
|
||||
if len(embiggen) < 2:
|
||||
embiggen.append(0.75)
|
||||
elif embiggen[1] > 1.0 or embiggen[1] < 0:
|
||||
embiggen[1] = 0.75
|
||||
print('>> Embiggen upscaling strength for ESRGAN must be between 0 and 1, fell back to the default of 0.75 !')
|
||||
if len(embiggen) < 3:
|
||||
embiggen.append(0.25)
|
||||
elif embiggen[2] < 0:
|
||||
embiggen[2] = 0.25
|
||||
print('>> Overlap size for Embiggen must be a positive ratio between 0 and 1 OR a number of pixels, fell back to the default of 0.25 !')
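# Illustrative argument shape (hypothetical values): after the checks above,
# embiggen is always a 3-element list,
#
#   embiggen = [2.0,   # overall scaling factor applied to the init image
#               0.75,  # ESRGAN upscaling strength (0 disables ESRGAN)
#               0.25]  # tile overlap, as a 0-1 ratio of tile size or in pixels
#
# and embiggen_tiles, if given, selects which (1-indexed) tiles to regenerate.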
|
||||
|
||||
# Convert tiles from their user-friendly count-from-one to count-from-zero, because we need to do modulo math
|
||||
# and then sort them, because... people.
|
||||
if embiggen_tiles:
|
||||
embiggen_tiles = list(map(lambda n: n-1, embiggen_tiles))
|
||||
embiggen_tiles.sort()
|
||||
|
||||
if strength >= 0.5:
|
||||
print(f'* WARNING: Embiggen may produce mirror motifs if the strength (-f) is too high (currently {strength}). Try values between 0.35-0.45.')
|
||||
|
||||
# Prep img2img generator, since we wrap over it
|
||||
gen_img2img = Img2Img(self.model,self.precision)
|
||||
|
||||
# Open original init image (not a tensor) to manipulate
|
||||
|
||||
with Image.open(init_img) as img:
|
||||
initsuperimage = img.convert('RGB')
|
||||
|
||||
# Size of the target super init image in pixels
|
||||
initsuperwidth, initsuperheight = initsuperimage.size
|
||||
|
||||
# Increase by scaling factor if not already resized, using ESRGAN as able
|
||||
if embiggen[0] != 1.0:
|
||||
initsuperwidth = round(initsuperwidth*embiggen[0])
|
||||
initsuperheight = round(initsuperheight*embiggen[0])
|
||||
if embiggen[1] > 0: # No point in ESRGAN upscaling if strength is set zero
|
||||
from ldm.invoke.restoration.realesrgan import ESRGAN
|
||||
esrgan = ESRGAN()
|
||||
print(
|
||||
f'>> ESRGAN upscaling init image prior to cutting with Embiggen with strength {embiggen[1]}')
|
||||
if embiggen[0] > 2:
|
||||
initsuperimage = esrgan.process(
|
||||
initsuperimage,
|
||||
embiggen[1], # upscale strength
|
||||
self.seed,
|
||||
4, # upscale scale
|
||||
)
|
||||
else:
|
||||
initsuperimage = esrgan.process(
|
||||
initsuperimage,
|
||||
embiggen[1], # upscale strength
|
||||
self.seed,
|
||||
2, # upscale scale
|
||||
)
|
||||
# We could keep recursively re-running ESRGAN for a requested embiggen[0] larger than 4x
|
||||
# but from personal experience it doesn't greatly improve anything after 4x
|
||||
# Resize to target scaling factor resolution
|
||||
initsuperimage = initsuperimage.resize(
|
||||
(initsuperwidth, initsuperheight), Image.Resampling.LANCZOS)
|
||||
|
||||
# Use width and height as tile widths and height
|
||||
# Determine buffer size in pixels
|
||||
if embiggen[2] < 1:
|
||||
if embiggen[2] < 0:
|
||||
embiggen[2] = 0
|
||||
overlap_size_x = round(embiggen[2] * width)
|
||||
overlap_size_y = round(embiggen[2] * height)
|
||||
else:
|
||||
overlap_size_x = round(embiggen[2])
|
||||
overlap_size_y = round(embiggen[2])
|
||||
|
||||
# With overall image width and height known, determine how many tiles we need
|
||||
def ceildiv(a, b):
|
||||
return -1 * (-a // b)
|
||||
|
||||
# X and Y need to be determined independently (we may have savings on one based on the buffer pixel count)
|
||||
# (initsuperwidth - width) is the area remaining to the right that we need to layer tiles over to fill
|
||||
# (width - overlap_size_x) is how much new we can fill with a single tile
|
||||
emb_tiles_x = 1
|
||||
emb_tiles_y = 1
|
||||
if (initsuperwidth - width) > 0:
|
||||
emb_tiles_x = ceildiv(initsuperwidth - width,
|
||||
width - overlap_size_x) + 1
|
||||
if (initsuperheight - height) > 0:
|
||||
emb_tiles_y = ceildiv(initsuperheight - height,
|
||||
height - overlap_size_y) + 1
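# Worked example (illustrative numbers): a 1280x1280 super-image cut into 512x512 tiles with a
# 128 px overlap needs ceildiv(1280 - 512, 512 - 128) + 1 = 3 tiles along each axis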
|
||||
# Sanity
|
||||
assert emb_tiles_x > 1 or emb_tiles_y > 1, f'ERROR: Based on the requested dimensions of {initsuperwidth}x{initsuperheight} and tiles of {width}x{height} you don\'t need to Embiggen! Check your arguments.'
|
||||
|
||||
# Prep alpha layers --------------
|
||||
# https://stackoverflow.com/questions/69321734/how-to-create-different-transparency-like-gradient-with-python-pil
|
||||
# agradientL is Left-side transparent
|
||||
agradientL = Image.linear_gradient('L').rotate(
|
||||
90).resize((overlap_size_x, height))
|
||||
# agradientT is Top-side transparent
|
||||
agradientT = Image.linear_gradient('L').resize((width, overlap_size_y))
|
||||
# radial corner is the left-top corner, made full circle then cut to just the left-top quadrant
|
||||
agradientC = Image.new('L', (256, 256))
|
||||
for y in range(256):
|
||||
for x in range(256):
|
||||
# Find distance to lower right corner (numpy takes arrays)
|
||||
distanceToLR = np.sqrt([(255 - x) ** 2 + (255 - y) ** 2])[0]
|
||||
# Clamp values to max 255
|
||||
if distanceToLR > 255:
|
||||
distanceToLR = 255
|
||||
#Place the pixel as invert of distance
|
||||
agradientC.putpixel((x, y), round(255 - distanceToLR))
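# Sanity check (illustrative): pixel (255, 255) sits on the corner, distance 0, alpha 255 (opaque);
# pixel (0, 0) is ~360 px away, clamped to 255, alpha 0 (fully transparent)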
|
||||
|
||||
# Create alternative asymmetric diagonal corner to use on "tailing" intersections to prevent hard edges
|
||||
# Fits for a left-fading gradient on the bottom side and full opacity on the right side.
|
||||
agradientAsymC = Image.new('L', (256, 256))
|
||||
for y in range(256):
|
||||
for x in range(256):
|
||||
value = round(max(0, x-(255-y)) * (255 / max(1,y)))
|
||||
#Clamp values
|
||||
value = max(0, value)
|
||||
value = min(255, value)
|
||||
agradientAsymC.putpixel((x, y), value)
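# Behaviour check (illustrative): the bottom row (y=255) ramps linearly from 0 at x=0 to 255 at
# x=255, while the top row (y=0) stays at 0, producing the asymmetric diagonal fade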
|
||||
|
||||
# Create alpha layers default fully white
|
||||
alphaLayerL = Image.new("L", (width, height), 255)
|
||||
alphaLayerT = Image.new("L", (width, height), 255)
|
||||
alphaLayerLTC = Image.new("L", (width, height), 255)
|
||||
# Paste gradients into alpha layers
|
||||
alphaLayerL.paste(agradientL, (0, 0))
|
||||
alphaLayerT.paste(agradientT, (0, 0))
|
||||
alphaLayerLTC.paste(agradientL, (0, 0))
|
||||
alphaLayerLTC.paste(agradientT, (0, 0))
|
||||
alphaLayerLTC.paste(agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0))
|
||||
# make masks with an asymmetric upper-right corner so when the curved transparent corner of the next tile
|
||||
# to its right is placed it doesn't reveal a hard trailing semi-transparent edge in the overlapping space
|
||||
alphaLayerTaC = alphaLayerT.copy()
|
||||
alphaLayerTaC.paste(agradientAsymC.rotate(270).resize((overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0))
|
||||
alphaLayerLTaC = alphaLayerLTC.copy()
|
||||
alphaLayerLTaC.paste(agradientAsymC.rotate(270).resize((overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0))
|
||||
|
||||
if embiggen_tiles:
|
||||
# Individual unconnected sides
|
||||
alphaLayerR = Image.new("L", (width, height), 255)
|
||||
alphaLayerR.paste(agradientL.rotate(
|
||||
180), (width - overlap_size_x, 0))
|
||||
alphaLayerB = Image.new("L", (width, height), 255)
|
||||
alphaLayerB.paste(agradientT.rotate(
|
||||
180), (0, height - overlap_size_y))
|
||||
alphaLayerTB = Image.new("L", (width, height), 255)
|
||||
alphaLayerTB.paste(agradientT, (0, 0))
|
||||
alphaLayerTB.paste(agradientT.rotate(
|
||||
180), (0, height - overlap_size_y))
|
||||
alphaLayerLR = Image.new("L", (width, height), 255)
|
||||
alphaLayerLR.paste(agradientL, (0, 0))
|
||||
alphaLayerLR.paste(agradientL.rotate(
|
||||
180), (width - overlap_size_x, 0))
|
||||
|
||||
# Sides and corner Layers
|
||||
alphaLayerRBC = Image.new("L", (width, height), 255)
|
||||
alphaLayerRBC.paste(agradientL.rotate(
|
||||
180), (width - overlap_size_x, 0))
|
||||
alphaLayerRBC.paste(agradientT.rotate(
|
||||
180), (0, height - overlap_size_y))
|
||||
alphaLayerRBC.paste(agradientC.rotate(180).resize(
|
||||
(overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y))
|
||||
alphaLayerLBC = Image.new("L", (width, height), 255)
|
||||
alphaLayerLBC.paste(agradientL, (0, 0))
|
||||
alphaLayerLBC.paste(agradientT.rotate(
|
||||
180), (0, height - overlap_size_y))
|
||||
alphaLayerLBC.paste(agradientC.rotate(90).resize(
|
||||
(overlap_size_x, overlap_size_y)), (0, height - overlap_size_y))
|
||||
alphaLayerRTC = Image.new("L", (width, height), 255)
|
||||
alphaLayerRTC.paste(agradientL.rotate(
|
||||
180), (width - overlap_size_x, 0))
|
||||
alphaLayerRTC.paste(agradientT, (0, 0))
|
||||
alphaLayerRTC.paste(agradientC.rotate(270).resize(
|
||||
(overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0))
|
||||
|
||||
# All but X layers
|
||||
alphaLayerABT = Image.new("L", (width, height), 255)
|
||||
alphaLayerABT.paste(alphaLayerLBC, (0, 0))
|
||||
alphaLayerABT.paste(agradientL.rotate(
|
||||
180), (width - overlap_size_x, 0))
|
||||
alphaLayerABT.paste(agradientC.rotate(180).resize(
|
||||
(overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y))
|
||||
alphaLayerABL = Image.new("L", (width, height), 255)
|
||||
alphaLayerABL.paste(alphaLayerRTC, (0, 0))
|
||||
alphaLayerABL.paste(agradientT.rotate(
|
||||
180), (0, height - overlap_size_y))
|
||||
alphaLayerABL.paste(agradientC.rotate(180).resize(
|
||||
(overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y))
|
||||
alphaLayerABR = Image.new("L", (width, height), 255)
|
||||
alphaLayerABR.paste(alphaLayerLBC, (0, 0))
|
||||
alphaLayerABR.paste(agradientT, (0, 0))
|
||||
alphaLayerABR.paste(agradientC.resize(
|
||||
(overlap_size_x, overlap_size_y)), (0, 0))
|
||||
alphaLayerABB = Image.new("L", (width, height), 255)
|
||||
alphaLayerABB.paste(alphaLayerRTC, (0, 0))
|
||||
alphaLayerABB.paste(agradientL, (0, 0))
|
||||
alphaLayerABB.paste(agradientC.resize(
|
||||
(overlap_size_x, overlap_size_y)), (0, 0))
|
||||
|
||||
# All-around layer
|
||||
alphaLayerAA = Image.new("L", (width, height), 255)
|
||||
alphaLayerAA.paste(alphaLayerABT, (0, 0))
|
||||
alphaLayerAA.paste(agradientT, (0, 0))
|
||||
alphaLayerAA.paste(agradientC.resize(
|
||||
(overlap_size_x, overlap_size_y)), (0, 0))
|
||||
alphaLayerAA.paste(agradientC.rotate(270).resize(
|
||||
(overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0))
|
||||
|
||||
# Clean up temporary gradients
|
||||
del agradientL
|
||||
del agradientT
|
||||
del agradientC
|
||||
|
||||
def make_image():
|
||||
# Make main tiles -------------------------------------------------
|
||||
if embiggen_tiles:
|
||||
print(f'>> Making {len(embiggen_tiles)} Embiggen tiles...')
|
||||
else:
|
||||
print(
|
||||
f'>> Making {(emb_tiles_x * emb_tiles_y)} Embiggen tiles ({emb_tiles_x}x{emb_tiles_y})...')
|
||||
|
||||
emb_tile_store = []
|
||||
# Although we could use the same seed for every tile for determinism, at higher strengths this may
# produce duplicated structures in each tile and make the tiling effect more obvious;
# instead we track and increment a local seed that we pass to Img2Img
seed = self.seed
|
||||
seedintlimit = np.iinfo(np.uint32).max - 1 # only retrieve this one from numpy
|
||||
for tile in range(emb_tiles_x * emb_tiles_y):
|
||||
# Don't iterate on first tile
|
||||
if tile != 0:
|
||||
if seed < seedintlimit:
|
||||
seed += 1
|
||||
else:
|
||||
seed = 0
|
||||
|
||||
# Determine if this is a re-run and skip tiles that were not requested
if embiggen_tiles and tile not in embiggen_tiles:
continue
|
||||
# Get row and column entries
|
||||
emb_row_i = tile // emb_tiles_x
|
||||
emb_column_i = tile % emb_tiles_x
|
||||
# Determine bounds to cut up the init image
|
||||
# Determine upper-left point
|
||||
if emb_column_i + 1 == emb_tiles_x:
|
||||
left = initsuperwidth - width
|
||||
else:
|
||||
left = round(emb_column_i * (width - overlap_size_x))
|
||||
if emb_row_i + 1 == emb_tiles_y:
|
||||
top = initsuperheight - height
|
||||
else:
|
||||
top = round(emb_row_i * (height - overlap_size_y))
|
||||
right = left + width
|
||||
bottom = top + height
|
||||
|
||||
# Cropped image of above dimension (does not modify the original)
|
||||
newinitimage = initsuperimage.crop((left, top, right, bottom))
|
||||
# DEBUG:
|
||||
# newinitimagepath = init_img[0:-4] + f'_emb_Ti{tile}.png'
|
||||
# newinitimage.save(newinitimagepath)
|
||||
|
||||
if embiggen_tiles:
|
||||
print(
|
||||
f'Making tile #{tile + 1} ({embiggen_tiles.index(tile) + 1} of {len(embiggen_tiles)} requested)')
|
||||
else:
|
||||
print(
|
||||
f'Starting {tile + 1} of {(emb_tiles_x * emb_tiles_y)} tiles')
|
||||
|
||||
# create a torch tensor from an Image
|
||||
newinitimage = np.array(
|
||||
newinitimage).astype(np.float32) / 255.0
|
||||
newinitimage = newinitimage[None].transpose(0, 3, 1, 2)
|
||||
newinitimage = torch.from_numpy(newinitimage)
|
||||
newinitimage = 2.0 * newinitimage - 1.0
|
||||
newinitimage = newinitimage.to(self.model.device)
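# newinitimage now has shape (1, 3, height, width) with values scaled to [-1, 1], the range
# the img2img pipeline expects for init images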
|
||||
clear_cuda_cache = kwargs.get('clear_cuda_cache', None)
|
||||
|
||||
tile_results = gen_img2img.generate(
|
||||
prompt,
|
||||
iterations = 1,
|
||||
seed = seed,
|
||||
sampler = sampler,
|
||||
steps = steps,
|
||||
cfg_scale = cfg_scale,
|
||||
conditioning = conditioning,
|
||||
ddim_eta = ddim_eta,
|
||||
image_callback = None, # called only after the final image is generated
|
||||
step_callback = step_callback, # called after each intermediate image is generated
|
||||
width = width,
|
||||
height = height,
|
||||
init_image = newinitimage, # notice that init_image is different from init_img
|
||||
mask_image = None,
|
||||
strength = strength,
|
||||
clear_cuda_cache = clear_cuda_cache
|
||||
)
|
||||
|
||||
emb_tile_store.append(tile_results[0][0])
|
||||
# DEBUG (but, also has other uses), worth saving if you want tiles without a transparency overlap to manually composite
|
||||
# emb_tile_store[-1].save(init_img[0:-4] + f'_emb_To{tile}.png')
|
||||
del newinitimage
|
||||
|
||||
# Sanity check we have them all
|
||||
if len(emb_tile_store) == (emb_tiles_x * emb_tiles_y) or (embiggen_tiles != [] and len(emb_tile_store) == len(embiggen_tiles)):
|
||||
outputsuperimage = Image.new(
|
||||
"RGBA", (initsuperwidth, initsuperheight))
|
||||
if embiggen_tiles:
|
||||
outputsuperimage.alpha_composite(
|
||||
initsuperimage.convert('RGBA'), (0, 0))
|
||||
for tile in range(emb_tiles_x * emb_tiles_y):
|
||||
if embiggen_tiles:
|
||||
if tile in embiggen_tiles:
|
||||
intileimage = emb_tile_store.pop(0)
|
||||
else:
|
||||
continue
|
||||
else:
|
||||
intileimage = emb_tile_store[tile]
|
||||
intileimage = intileimage.convert('RGBA')
|
||||
# Get row and column entries
|
||||
emb_row_i = tile // emb_tiles_x
|
||||
emb_column_i = tile % emb_tiles_x
|
||||
if emb_row_i == 0 and emb_column_i == 0 and not embiggen_tiles:
|
||||
left = 0
|
||||
top = 0
|
||||
else:
|
||||
# Determine upper-left point
|
||||
if emb_column_i + 1 == emb_tiles_x:
|
||||
left = initsuperwidth - width
|
||||
else:
|
||||
left = round(emb_column_i *
|
||||
(width - overlap_size_x))
|
||||
if emb_row_i + 1 == emb_tiles_y:
|
||||
top = initsuperheight - height
|
||||
else:
|
||||
top = round(emb_row_i * (height - overlap_size_y))
|
||||
# Handle gradients for various conditions
|
||||
# Handle emb_rerun case
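# When only selected tiles are re-rendered, the look-ahead checks below pick an alpha mask that
# feathers just the edges shared with another freshly generated tile, so neighbours that were not
# re-run keep their original pixels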
|
||||
if embiggen_tiles:
|
||||
# top of image
|
||||
if emb_row_i == 0:
|
||||
if emb_column_i == 0:
|
||||
if (tile+1) in embiggen_tiles: # Look-ahead right
|
||||
if (tile+emb_tiles_x) not in embiggen_tiles: # Look-ahead down
|
||||
intileimage.putalpha(alphaLayerB)
|
||||
# Otherwise do nothing on this tile
|
||||
elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only
|
||||
intileimage.putalpha(alphaLayerR)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerRBC)
|
||||
elif emb_column_i == emb_tiles_x - 1:
|
||||
if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down
|
||||
intileimage.putalpha(alphaLayerL)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerLBC)
|
||||
else:
|
||||
if (tile+1) in embiggen_tiles: # Look-ahead right
|
||||
if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down
|
||||
intileimage.putalpha(alphaLayerL)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerLBC)
|
||||
elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only
|
||||
intileimage.putalpha(alphaLayerLR)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerABT)
|
||||
# bottom of image
|
||||
elif emb_row_i == emb_tiles_y - 1:
|
||||
if emb_column_i == 0:
|
||||
if (tile+1) in embiggen_tiles: # Look-ahead right
|
||||
intileimage.putalpha(alphaLayerTaC)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerRTC)
|
||||
elif emb_column_i == emb_tiles_x - 1:
|
||||
# No tiles to look ahead to
|
||||
intileimage.putalpha(alphaLayerLTC)
|
||||
else:
|
||||
if (tile+1) in embiggen_tiles: # Look-ahead right
|
||||
intileimage.putalpha(alphaLayerLTaC)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerABB)
|
||||
# vertical middle of image
|
||||
else:
|
||||
if emb_column_i == 0:
|
||||
if (tile+1) in embiggen_tiles: # Look-ahead right
|
||||
if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down
|
||||
intileimage.putalpha(alphaLayerTaC)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerTB)
|
||||
elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only
|
||||
intileimage.putalpha(alphaLayerRTC)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerABL)
|
||||
elif emb_column_i == emb_tiles_x - 1:
|
||||
if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down
|
||||
intileimage.putalpha(alphaLayerLTC)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerABR)
|
||||
else:
|
||||
if (tile+1) in embiggen_tiles: # Look-ahead right
|
||||
if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down
|
||||
intileimage.putalpha(alphaLayerLTaC)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerABR)
|
||||
elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only
|
||||
intileimage.putalpha(alphaLayerABB)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerAA)
|
||||
# Handle normal tiling case (much simpler - since we tile left to right, top to bottom)
|
||||
else:
|
||||
if emb_row_i == 0 and emb_column_i >= 1:
|
||||
intileimage.putalpha(alphaLayerL)
|
||||
elif emb_row_i >= 1 and emb_column_i == 0:
|
||||
if emb_column_i + 1 == emb_tiles_x: # If we don't have anything that can be placed to the right
|
||||
intileimage.putalpha(alphaLayerT)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerTaC)
|
||||
else:
|
||||
if emb_column_i + 1 == emb_tiles_x: # If we don't have anything that can be placed to the right
|
||||
intileimage.putalpha(alphaLayerLTC)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerLTaC)
|
||||
# Layer tile onto final image
|
||||
outputsuperimage.alpha_composite(intileimage, (left, top))
|
||||
else:
|
||||
print('Error: could not find all Embiggen output tiles in memory; something must have gone wrong with img2img generation.')
|
||||
|
||||
# after internal loops and patching up return Embiggen image
|
||||
return outputsuperimage
|
||||
# end of function declaration
|
||||
return make_image
|
69
invokeai/generator/img2img.py
Normal file
@ -0,0 +1,69 @@
|
||||
'''
|
||||
invokeai.backend.generator.img2img descends from ldm.invoke.generator
|
||||
'''
|
||||
|
||||
import torch
|
||||
from diffusers import logging
|
||||
|
||||
from .base import Generator
|
||||
from .diffusers_pipeline import StableDiffusionGeneratorPipeline, ConditioningData
|
||||
from ..models.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
||||
|
||||
class Img2Img(Generator):
|
||||
def __init__(self, model, precision):
|
||||
super().__init__(model, precision)
|
||||
self.init_latent = None # for get_noise()
|
||||
|
||||
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
|
||||
conditioning,init_image,strength,step_callback=None,threshold=0.0,warmup=0.2,perlin=0.0,
|
||||
h_symmetry_time_pct=None,v_symmetry_time_pct=None,attention_maps_callback=None,
|
||||
**kwargs):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and the initial image
|
||||
Return value depends on the seed at the time you call it.
|
||||
"""
|
||||
self.perlin = perlin
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
pipeline: StableDiffusionGeneratorPipeline = self.model
|
||||
pipeline.scheduler = sampler
|
||||
|
||||
uc, c, extra_conditioning_info = conditioning
|
||||
conditioning_data = (
|
||||
ConditioningData(
|
||||
uc, c, cfg_scale, extra_conditioning_info,
|
||||
postprocessing_settings=PostprocessingSettings(
|
||||
threshold=threshold,
|
||||
warmup=warmup,
|
||||
h_symmetry_time_pct=h_symmetry_time_pct,
|
||||
v_symmetry_time_pct=v_symmetry_time_pct
|
||||
)
|
||||
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))
|
||||
|
||||
|
||||
def make_image(x_T):
|
||||
# FIXME: use x_T for initial seeded noise
|
||||
# We're not at the moment because the pipeline automatically resizes init_image if
|
||||
# necessary, which the x_T input might not match.
|
||||
logging.set_verbosity_error() # quench safety check warnings
|
||||
pipeline_output = pipeline.img2img_from_embeddings(
|
||||
init_image, strength, steps, conditioning_data,
|
||||
noise_func=self.get_noise_like,
|
||||
callback=step_callback
|
||||
)
|
||||
if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
|
||||
attention_maps_callback(pipeline_output.attention_map_saver)
|
||||
return pipeline.numpy_to_pil(pipeline_output.images)[0]
|
||||
|
||||
return make_image
|
||||
|
||||
def get_noise_like(self, like: torch.Tensor):
|
||||
device = like.device
|
||||
if device.type == 'mps':
|
||||
x = torch.randn_like(like, device='cpu').to(device)
|
||||
else:
|
||||
x = torch.randn_like(like, device=device)
|
||||
if self.perlin > 0.0:
|
||||
shape = like.shape
|
||||
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
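# convex blend: e.g. (illustrative) perlin=0.1 keeps 90% gaussian noise and mixes in 10% perlin
# noise generated at the latent resolution (shape[3] x shape[2] is width x height)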
|
||||
return x
|
324
invokeai/generator/inpaint.py
Normal file
@ -0,0 +1,324 @@
|
||||
'''
|
||||
invokeai.backend.generator.inpaint descends from ldm.invoke.generator
|
||||
'''
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
|
||||
import PIL
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image, ImageFilter, ImageOps, ImageChops
|
||||
|
||||
from .diffusers_pipeline import image_resized_to_grid_as_tensor, StableDiffusionGeneratorPipeline, \
|
||||
ConditioningData
|
||||
from .img2img import Img2Img
|
||||
from ldm.invoke.patchmatch import PatchMatch
|
||||
from ldm.util import debug_image
|
||||
|
||||
|
||||
def infill_methods()->list[str]:
|
||||
methods = [
|
||||
"tile",
|
||||
"solid",
|
||||
]
|
||||
if PatchMatch.patchmatch_available():
|
||||
methods.insert(0, 'patchmatch')
|
||||
return methods
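# Illustrative result: ['patchmatch', 'tile', 'solid'] when patchmatch is installed,
# otherwise ['tile', 'solid']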
|
||||
|
||||
class Inpaint(Img2Img):
|
||||
def __init__(self, model, precision):
|
||||
self.inpaint_height = 0
|
||||
self.inpaint_width = 0
|
||||
self.enable_image_debugging = False
|
||||
self.init_latent = None
|
||||
self.pil_image = None
|
||||
self.pil_mask = None
|
||||
self.mask_blur_radius = 0
|
||||
self.infill_method = None
|
||||
super().__init__(model, precision)
|
||||
|
||||
# Outpaint support code
|
||||
def get_tile_images(self, image: np.ndarray, width=8, height=8):
|
||||
_nrows, _ncols, depth = image.shape
|
||||
_strides = image.strides
|
||||
|
||||
nrows, _m = divmod(_nrows, height)
|
||||
ncols, _n = divmod(_ncols, width)
|
||||
if _m != 0 or _n != 0:
|
||||
return None
|
||||
|
||||
return np.lib.stride_tricks.as_strided(
|
||||
np.ravel(image),
|
||||
shape=(nrows, ncols, height, width, depth),
|
||||
strides=(height * _strides[0], width * _strides[1], *_strides),
|
||||
writeable=False
|
||||
)
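# Illustrative shapes: a 512x512x4 RGBA array with width=height=16 yields a read-only
# (32, 32, 16, 16, 4) view of tiles; None is returned when the image is not an exact
# multiple of the tile size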
|
||||
|
||||
def infill_patchmatch(self, im: Image.Image) -> Image:
|
||||
if im.mode != 'RGBA':
|
||||
return im
|
||||
|
||||
# Skip patchmatch if patchmatch isn't available
|
||||
if not PatchMatch.patchmatch_available():
|
||||
return im
|
||||
|
||||
# Patchmatch (note, we may want to expose patch_size? Increasing it significantly impacts performance though)
|
||||
im_patched_np = PatchMatch.inpaint(im.convert('RGB'), ImageOps.invert(im.split()[-1]), patch_size = 3)
|
||||
im_patched = Image.fromarray(im_patched_np, mode = 'RGB')
|
||||
return im_patched
|
||||
|
||||
def tile_fill_missing(self, im: Image.Image, tile_size: int = 16, seed: int = None) -> Image:
|
||||
# Only fill if there's an alpha layer
|
||||
if im.mode != 'RGBA':
|
||||
return im
|
||||
|
||||
a = np.asarray(im, dtype=np.uint8)
|
||||
|
||||
tile_size = (tile_size, tile_size)
|
||||
|
||||
# Get the image as tiles of a specified size
|
||||
tiles = self.get_tile_images(a,*tile_size).copy()
|
||||
|
||||
# Get the mask as tiles
|
||||
tiles_mask = tiles[:,:,:,:,3]
|
||||
|
||||
# Find any mask tiles with any fully transparent pixels (we will be replacing these later)
|
||||
tmask_shape = tiles_mask.shape
|
||||
tiles_mask = tiles_mask.reshape(math.prod(tiles_mask.shape))
|
||||
n,ny = (math.prod(tmask_shape[0:2])), math.prod(tmask_shape[2:])
|
||||
tiles_mask = (tiles_mask > 0)
|
||||
tiles_mask = tiles_mask.reshape((n,ny)).all(axis = 1)
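# e.g. (illustrative) a 512x512 RGBA image with 16x16 tiles gives tiles_mask shape (1024,):
# one boolean per tile, True only when every pixel in that tile is opaque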
|
||||
|
||||
# Get RGB tiles in single array and filter by the mask
|
||||
tshape = tiles.shape
|
||||
tiles_all = tiles.reshape((math.prod(tiles.shape[0:2]), * tiles.shape[2:]))
|
||||
filtered_tiles = tiles_all[tiles_mask]
|
||||
|
||||
if len(filtered_tiles) == 0:
|
||||
return im
|
||||
|
||||
# Find all invalid tiles and replace with a random valid tile
|
||||
replace_count = (tiles_mask == False).sum()
|
||||
rng = np.random.default_rng(seed = seed)
|
||||
tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[rng.choice(filtered_tiles.shape[0], replace_count),:,:,:]
|
||||
|
||||
# Convert back to an image
|
||||
tiles_all = tiles_all.reshape(tshape)
|
||||
tiles_all = tiles_all.swapaxes(1,2)
|
||||
st = tiles_all.reshape((math.prod(tiles_all.shape[0:2]), math.prod(tiles_all.shape[2:4]), tiles_all.shape[4]))
|
||||
si = Image.fromarray(st, mode='RGBA')
|
||||
|
||||
return si
|
||||
|
||||
|
||||
def mask_edge(self, mask: Image, edge_size: int, edge_blur: int) -> Image:
|
||||
npimg = np.asarray(mask, dtype=np.uint8)
|
||||
|
||||
# Detect any partially transparent regions
|
||||
npgradient = np.uint8(255 * (1.0 - np.floor(np.abs(0.5 - np.float32(npimg) / 255.0) * 2.0)))
|
||||
|
||||
# Detect hard edges
|
||||
npedge = cv2.Canny(npimg, threshold1=100, threshold2=200)
|
||||
|
||||
# Combine
|
||||
npmask = npgradient + npedge
|
||||
|
||||
# Expand
|
||||
npmask = cv2.dilate(npmask, np.ones((3,3), np.uint8), iterations = int(edge_size / 2))
|
||||
|
||||
new_mask = Image.fromarray(npmask)
|
||||
|
||||
if edge_blur > 0:
|
||||
new_mask = new_mask.filter(ImageFilter.BoxBlur(edge_blur))
|
||||
|
||||
return ImageOps.invert(new_mask)
|
||||
|
||||
|
||||
def seam_paint(self, im: Image.Image, seam_size: int, seam_blur: int, prompt, sampler, steps, cfg_scale, ddim_eta,
|
||||
conditioning, strength, noise, infill_method, step_callback) -> Image.Image:
|
||||
hard_mask = self.pil_image.split()[-1].copy()
|
||||
mask = self.mask_edge(hard_mask, seam_size, seam_blur)
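# The resulting seam mask covers only a thin band (roughly seam_size px, softened by seam_blur)
# around the original alpha edge, so this second pass repaints just the visible seam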
|
||||
|
||||
make_image = self.get_make_image(
|
||||
prompt,
|
||||
sampler,
|
||||
steps,
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
conditioning,
|
||||
init_image = im.copy().convert('RGBA'),
|
||||
mask_image = mask,
|
||||
strength = strength,
|
||||
mask_blur_radius = 0,
|
||||
seam_size = 0,
|
||||
step_callback = step_callback,
|
||||
inpaint_width = im.width,
|
||||
inpaint_height = im.height,
|
||||
infill_method = infill_method
|
||||
)
|
||||
|
||||
seam_noise = self.get_noise(im.width, im.height)
|
||||
|
||||
result = make_image(seam_noise)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
|
||||
conditioning,
|
||||
init_image: PIL.Image.Image | torch.FloatTensor,
|
||||
mask_image: PIL.Image.Image | torch.FloatTensor,
|
||||
strength: float,
|
||||
mask_blur_radius: int = 8,
|
||||
# Seam settings - when 0, doesn't fill seam
|
||||
seam_size: int = 0,
|
||||
seam_blur: int = 0,
|
||||
seam_strength: float = 0.7,
|
||||
seam_steps: int = 10,
|
||||
tile_size: int = 32,
|
||||
step_callback=None,
|
||||
inpaint_replace=False, enable_image_debugging=False,
|
||||
infill_method = None,
|
||||
inpaint_width=None,
|
||||
inpaint_height=None,
|
||||
inpaint_fill: tuple[int, int, int, int]=(0x7F, 0x7F, 0x7F, 0xFF),
|
||||
attention_maps_callback=None,
|
||||
**kwargs):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and
|
||||
the initial image + mask. Return value depends on the seed at
|
||||
the time you call it. kwargs are 'init_latent' and 'strength'
|
||||
"""
|
||||
|
||||
self.enable_image_debugging = enable_image_debugging
|
||||
infill_method = infill_method or infill_methods()[0]
|
||||
self.infill_method = infill_method
|
||||
|
||||
self.inpaint_width = inpaint_width
|
||||
self.inpaint_height = inpaint_height
|
||||
|
||||
if isinstance(init_image, PIL.Image.Image):
|
||||
self.pil_image = init_image.copy()
|
||||
|
||||
# Do infill
|
||||
if infill_method == 'patchmatch' and PatchMatch.patchmatch_available():
|
||||
init_filled = self.infill_patchmatch(self.pil_image.copy())
|
||||
elif infill_method == 'tile':
|
||||
init_filled = self.tile_fill_missing(
|
||||
self.pil_image.copy(),
|
||||
seed = self.seed,
|
||||
tile_size = tile_size
|
||||
)
|
||||
elif infill_method == 'solid':
|
||||
solid_bg = PIL.Image.new("RGBA", init_image.size, inpaint_fill)
|
||||
init_filled = PIL.Image.alpha_composite(solid_bg, init_image)
|
||||
else:
|
||||
raise ValueError(f"Non-supported infill type {infill_method}", infill_method)
|
||||
init_filled.paste(init_image, (0,0), init_image.split()[-1])
|
||||
|
||||
# Resize if requested for inpainting
|
||||
if inpaint_width and inpaint_height:
|
||||
init_filled = init_filled.resize((inpaint_width, inpaint_height))
|
||||
|
||||
debug_image(init_filled, "init_filled", debug_status=self.enable_image_debugging)
|
||||
|
||||
# Create init tensor
|
||||
init_image = image_resized_to_grid_as_tensor(init_filled.convert('RGB'))
|
||||
|
||||
if isinstance(mask_image, PIL.Image.Image):
|
||||
self.pil_mask = mask_image.copy()
|
||||
debug_image(mask_image, "mask_image BEFORE multiply with pil_image", debug_status=self.enable_image_debugging)
|
||||
|
||||
init_alpha = self.pil_image.getchannel("A")
|
||||
if mask_image.mode != "L":
|
||||
# FIXME: why do we get passed an RGB image here? We can only use single-channel.
|
||||
mask_image = mask_image.convert("L")
|
||||
mask_image = ImageChops.multiply(mask_image, init_alpha)
|
||||
self.pil_mask = mask_image
|
||||
|
||||
# Resize if requested for inpainting
|
||||
if inpaint_width and inpaint_height:
|
||||
mask_image = mask_image.resize((inpaint_width, inpaint_height))
|
||||
|
||||
debug_image(mask_image, "mask_image AFTER multiply with pil_image", debug_status=self.enable_image_debugging)
|
||||
mask: torch.FloatTensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
|
||||
else:
|
||||
mask: torch.FloatTensor = mask_image
|
||||
|
||||
self.mask_blur_radius = mask_blur_radius
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
pipeline: StableDiffusionGeneratorPipeline = self.model
|
||||
pipeline.scheduler = sampler
|
||||
|
||||
# todo: support cross-attention control
|
||||
uc, c, _ = conditioning
|
||||
conditioning_data = (ConditioningData(uc, c, cfg_scale)
|
||||
.add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))
|
||||
|
||||
|
||||
def make_image(x_T):
|
||||
pipeline_output = pipeline.inpaint_from_embeddings(
|
||||
init_image=init_image,
|
||||
mask=1 - mask, # expects white means "paint here."
|
||||
strength=strength,
|
||||
num_inference_steps=steps,
|
||||
conditioning_data=conditioning_data,
|
||||
noise_func=self.get_noise_like,
|
||||
callback=step_callback,
|
||||
)
|
||||
|
||||
if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
|
||||
attention_maps_callback(pipeline_output.attention_map_saver)
|
||||
|
||||
result = self.postprocess_size_and_mask(pipeline.numpy_to_pil(pipeline_output.images)[0])
|
||||
|
||||
# Seam paint if this is our first pass (seam_size set to 0 during seam painting)
|
||||
if seam_size > 0:
|
||||
old_image = self.pil_image or init_image
|
||||
old_mask = self.pil_mask or mask_image
|
||||
|
||||
result = self.seam_paint(result, seam_size, seam_blur, prompt, sampler, seam_steps, cfg_scale, ddim_eta,
|
||||
conditioning, seam_strength, x_T, infill_method, step_callback)
|
||||
|
||||
# Restore original settings
|
||||
self.get_make_image(prompt,sampler,steps,cfg_scale,ddim_eta,
|
||||
conditioning,
|
||||
old_image,
|
||||
old_mask,
|
||||
strength,
|
||||
mask_blur_radius, seam_size, seam_blur, seam_strength,
|
||||
seam_steps, tile_size, step_callback,
|
||||
inpaint_replace, enable_image_debugging,
|
||||
inpaint_width = inpaint_width,
|
||||
inpaint_height = inpaint_height,
|
||||
infill_method = infill_method,
|
||||
**kwargs)
|
||||
|
||||
return result
|
||||
|
||||
return make_image
|
||||
|
||||
|
||||
def sample_to_image(self, samples)->Image.Image:
|
||||
gen_result = super().sample_to_image(samples).convert('RGB')
|
||||
return self.postprocess_size_and_mask(gen_result)
|
||||
|
||||
|
||||
def postprocess_size_and_mask(self, gen_result: Image.Image) -> Image.Image:
|
||||
debug_image(gen_result, "gen_result", debug_status=self.enable_image_debugging)
|
||||
|
||||
# Resize if necessary
|
||||
if self.inpaint_width and self.inpaint_height:
|
||||
gen_result = gen_result.resize(self.pil_image.size)
|
||||
|
||||
if self.pil_image is None or self.pil_mask is None:
|
||||
return gen_result
|
||||
|
||||
corrected_result = self.repaste_and_color_correct(gen_result, self.pil_image, self.pil_mask, self.mask_blur_radius)
|
||||
debug_image(corrected_result, "corrected_result", debug_status=self.enable_image_debugging)
|
||||
|
||||
return corrected_result
|
173
invokeai/generator/omnibus.py
Normal file
@ -0,0 +1,173 @@
|
||||
"""omnibus module to be used with the runwayml 9-channel custom inpainting model"""
|
||||
|
||||
import torch
|
||||
from PIL import Image, ImageChops, ImageOps
|
||||
from einops import repeat
|
||||
|
||||
from ldm.invoke.devices import choose_autocast
|
||||
from ldm.invoke.generator.img2img import Img2Img
|
||||
from ldm.invoke.generator.txt2img import Txt2Img
|
||||
|
||||
|
||||
class Omnibus(Img2Img,Txt2Img):
|
||||
def __init__(self, model, precision):
|
||||
super().__init__(model, precision)
|
||||
self.pil_mask = None
|
||||
self.pil_image = None
|
||||
|
||||
def get_make_image(
|
||||
self,
|
||||
prompt,
|
||||
sampler,
|
||||
steps,
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
conditioning,
|
||||
width,
|
||||
height,
|
||||
init_image = None,
|
||||
mask_image = None,
|
||||
strength = None,
|
||||
step_callback=None,
|
||||
threshold=0.0,
|
||||
perlin=0.0,
|
||||
mask_blur_radius: int = 8,
|
||||
**kwargs):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and the initial image
|
||||
Return value depends on the seed at the time you call it.
|
||||
"""
|
||||
self.perlin = perlin
|
||||
num_samples = 1
|
||||
|
||||
sampler.make_schedule(
|
||||
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
|
||||
)
|
||||
|
||||
if isinstance(init_image, Image.Image):
|
||||
self.pil_image = init_image
|
||||
if init_image.mode != 'RGB':
|
||||
init_image = init_image.convert('RGB')
|
||||
init_image = self._image_to_tensor(init_image)
|
||||
|
||||
if isinstance(mask_image, Image.Image):
|
||||
self.pil_mask = mask_image
|
||||
|
||||
mask_image = ImageChops.multiply(mask_image.convert('L'), self.pil_image.split()[-1])
|
||||
mask_image = self._image_to_tensor(ImageOps.invert(mask_image), normalize=False)
|
||||
|
||||
self.mask_blur_radius = mask_blur_radius
|
||||
|
||||
if init_image is not None and mask_image is not None: # inpainting
|
||||
masked_image = init_image * (1 - mask_image) # masked image is the image masked by mask - masked regions zero
|
||||
|
||||
elif init_image is not None: # img2img
|
||||
scope = choose_autocast(self.precision)
|
||||
|
||||
with scope(self.model.device.type):
|
||||
self.init_latent = self.model.get_first_stage_encoding(
|
||||
self.model.encode_first_stage(init_image)
|
||||
) # move to latent space
|
||||
|
||||
# create a completely black mask (1s)
|
||||
mask_image = torch.ones(1, 1, init_image.shape[2], init_image.shape[3], device=self.model.device)
|
||||
# and the masked image is just a copy of the original
|
||||
masked_image = init_image
|
||||
|
||||
else: # txt2img
|
||||
init_image = torch.zeros(1, 3, height, width, device=self.model.device)
|
||||
mask_image = torch.ones(1, 1, height, width, device=self.model.device)
|
||||
masked_image = init_image
|
||||
|
||||
self.init_latent = init_image
|
||||
height = init_image.shape[2]
|
||||
width = init_image.shape[3]
|
||||
model = self.model
|
||||
|
||||
def make_image(x_T):
|
||||
with torch.no_grad():
|
||||
scope = choose_autocast(self.precision)
|
||||
with scope(self.model.device.type):
|
||||
|
||||
batch = self.make_batch_sd(
|
||||
init_image,
|
||||
mask_image,
|
||||
masked_image,
|
||||
prompt=prompt,
|
||||
device=model.device,
|
||||
num_samples=num_samples,
|
||||
)
|
||||
|
||||
c = model.cond_stage_model.encode(batch["txt"])
|
||||
c_cat = list()
|
||||
for ck in model.concat_keys:
|
||||
cc = batch[ck].float()
|
||||
if ck != model.masked_image_key:
|
||||
bchw = [num_samples, 4, height//8, width//8]
|
||||
cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
|
||||
else:
|
||||
cc = model.get_first_stage_encoding(model.encode_first_stage(cc))
|
||||
c_cat.append(cc)
|
||||
c_cat = torch.cat(c_cat, dim=1)
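# c_cat now holds the 1-channel downsampled mask plus the 4-channel masked-image latents;
# together with the 4 noised latent channels this is the assumed 9-channel inpainting UNet input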
|
||||
|
||||
# cond
|
||||
cond={"c_concat": [c_cat], "c_crossattn": [c]}
|
||||
|
||||
# uncond cond
|
||||
uc_cross = model.get_unconditional_conditioning(num_samples, "")
|
||||
uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
|
||||
shape = [model.channels, height//8, width//8]
|
||||
|
||||
samples, _ = sampler.sample(
|
||||
batch_size = 1,
|
||||
S = steps,
|
||||
x_T = x_T,
|
||||
conditioning = cond,
|
||||
shape = shape,
|
||||
verbose = False,
|
||||
unconditional_guidance_scale = cfg_scale,
|
||||
unconditional_conditioning = uc_full,
|
||||
eta = 1.0,
|
||||
img_callback = step_callback,
|
||||
threshold = threshold,
|
||||
)
|
||||
if self.free_gpu_mem:
|
||||
self.model.model.to("cpu")
|
||||
return self.sample_to_image(samples)
|
||||
|
||||
return make_image
|
||||
|
||||
def make_batch_sd(
|
||||
self,
|
||||
image,
|
||||
mask,
|
||||
masked_image,
|
||||
prompt,
|
||||
device,
|
||||
num_samples=1):
|
||||
batch = {
|
||||
"image": repeat(image.to(device=device), "1 ... -> n ...", n=num_samples),
|
||||
"txt": num_samples * [prompt],
|
||||
"mask": repeat(mask.to(device=device), "1 ... -> n ...", n=num_samples),
|
||||
"masked_image": repeat(masked_image.to(device=device), "1 ... -> n ...", n=num_samples),
|
||||
}
|
||||
return batch
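# Illustrative shapes for a 512x512 input with num_samples=1: image (1, 3, 512, 512),
# mask (1, 1, 512, 512), masked_image (1, 3, 512, 512); "txt" is a one-element prompt list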
|
||||
|
||||
def get_noise(self, width:int, height:int):
|
||||
if self.init_latent is not None:
|
||||
height = self.init_latent.shape[2]
|
||||
width = self.init_latent.shape[3]
|
||||
return Txt2Img.get_noise(self,width,height)
|
||||
|
||||
|
||||
def sample_to_image(self, samples)->Image.Image:
|
||||
gen_result = super().sample_to_image(samples).convert('RGB')
|
||||
|
||||
if self.pil_image is None or self.pil_mask is None:
|
||||
return gen_result
|
||||
if self.pil_image.size != self.pil_mask.size:
|
||||
return gen_result
|
||||
|
||||
corrected_result = super(Img2Img, self).repaste_and_color_correct(gen_result, self.pil_image, self.pil_mask, self.mask_blur_radius)
|
||||
|
||||
return corrected_result
|
60
invokeai/generator/txt2img.py
Normal file
@ -0,0 +1,60 @@
|
||||
'''
|
||||
invokeai.backend.generator.txt2img inherits from invokeai.backend.generator
|
||||
'''
|
||||
import PIL.Image
|
||||
import torch
|
||||
|
||||
from .base import Generator
|
||||
from .diffusers_pipeline import StableDiffusionGeneratorPipeline, ConditioningData
|
||||
from ..models import PostprocessingSettings
|
||||
|
||||
class Txt2Img(Generator):
|
||||
def __init__(self, model, precision):
|
||||
super().__init__(model, precision)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
|
||||
conditioning,width,height,step_callback=None,threshold=0.0,warmup=0.2,perlin=0.0,
|
||||
h_symmetry_time_pct=None,v_symmetry_time_pct=None,attention_maps_callback=None,
|
||||
**kwargs):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and the initial image
|
||||
Return value depends on the seed at the time you call it
|
||||
kwargs are 'width' and 'height'
|
||||
"""
|
||||
self.perlin = perlin
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
pipeline: StableDiffusionGeneratorPipeline = self.model
|
||||
pipeline.scheduler = sampler
|
||||
|
||||
uc, c, extra_conditioning_info = conditioning
|
||||
conditioning_data = (
|
||||
ConditioningData(
|
||||
uc, c, cfg_scale, extra_conditioning_info,
|
||||
postprocessing_settings=PostprocessingSettings(
|
||||
threshold=threshold,
|
||||
warmup=warmup,
|
||||
h_symmetry_time_pct=h_symmetry_time_pct,
|
||||
v_symmetry_time_pct=v_symmetry_time_pct
|
||||
)
|
||||
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))
|
||||
|
||||
def make_image(x_T) -> PIL.Image.Image:
|
||||
pipeline_output = pipeline.image_from_embeddings(
|
||||
latents=torch.zeros_like(x_T,dtype=self.torch_dtype()),
|
||||
noise=x_T,
|
||||
num_inference_steps=steps,
|
||||
conditioning_data=conditioning_data,
|
||||
callback=step_callback,
|
||||
)
|
||||
|
||||
if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
|
||||
attention_maps_callback(pipeline_output.attention_map_saver)
|
||||
|
||||
return pipeline.numpy_to_pil(pipeline_output.images)[0]
|
||||
|
||||
return make_image
|
||||
|
||||
|
||||
|
163
invokeai/generator/txt2img2img.py
Normal file
@ -0,0 +1,163 @@
|
||||
'''
|
||||
invokeai.backend.generator.txt2img2img inherits from invokeai.backend.generator
|
||||
'''
|
||||
|
||||
import math
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from diffusers.utils.logging import get_verbosity, set_verbosity, set_verbosity_error
|
||||
|
||||
from .base import Generator
|
||||
from .diffusers_pipeline import trim_to_multiple_of, StableDiffusionGeneratorPipeline, \
|
||||
ConditioningData
|
||||
from ..models import PostprocessingSettings
|
||||
|
||||
|
||||
class Txt2Img2Img(Generator):
|
||||
def __init__(self, model, precision):
|
||||
super().__init__(model, precision)
|
||||
self.init_latent = None # for get_noise()
|
||||
|
||||
def get_make_image(self, prompt:str, sampler, steps:int, cfg_scale:float, ddim_eta,
|
||||
conditioning, width:int, height:int, strength:float,
|
||||
step_callback:Optional[Callable]=None, threshold=0.0, warmup=0.2, perlin=0.0,
|
||||
h_symmetry_time_pct=None, v_symmetry_time_pct=None, attention_maps_callback=None, **kwargs):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and the initial image
|
||||
Return value depends on the seed at the time you call it
|
||||
kwargs are 'width' and 'height'
|
||||
"""
|
||||
self.perlin = perlin
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
pipeline: StableDiffusionGeneratorPipeline = self.model
|
||||
pipeline.scheduler = sampler
|
||||
|
||||
uc, c, extra_conditioning_info = conditioning
|
||||
conditioning_data = (
|
||||
ConditioningData(
|
||||
uc, c, cfg_scale, extra_conditioning_info,
|
||||
postprocessing_settings = PostprocessingSettings(
|
||||
threshold=threshold,
|
||||
warmup=warmup,
|
||||
h_symmetry_time_pct=h_symmetry_time_pct,
|
||||
v_symmetry_time_pct=v_symmetry_time_pct
|
||||
)
|
||||
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))
|
||||
|
||||
def make_image(x_T):
|
||||
|
||||
first_pass_latent_output, _ = pipeline.latents_from_embeddings(
|
||||
latents=torch.zeros_like(x_T),
|
||||
num_inference_steps=steps,
|
||||
conditioning_data=conditioning_data,
|
||||
noise=x_T,
|
||||
callback=step_callback,
|
||||
)
|
||||
|
||||
# Get our initial generation width and height directly from the latent output so
|
||||
# the message below is accurate.
|
||||
init_width = first_pass_latent_output.size()[3] * self.downsampling_factor
|
||||
init_height = first_pass_latent_output.size()[2] * self.downsampling_factor
|
||||
print(
|
||||
f"\n>> Interpolating from {init_width}x{init_height} to {width}x{height} using DDIM sampling"
|
||||
)
|
||||
|
||||
# resizing
|
||||
resized_latents = torch.nn.functional.interpolate(
|
||||
first_pass_latent_output,
|
||||
size=(height // self.downsampling_factor, width // self.downsampling_factor),
|
||||
mode="bilinear"
|
||||
)
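# e.g. (illustrative) a 512x512 first pass produces latents of shape (1, 4, 64, 64); requesting
# 768x768 resizes them to (1, 4, 96, 96) before the second img2img pass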
|
||||
|
||||
# Free up memory from the last generation.
|
||||
clear_cuda_cache = kwargs.get('clear_cuda_cache', None)
|
||||
if clear_cuda_cache is not None:
|
||||
clear_cuda_cache()
|
||||
|
||||
second_pass_noise = self.get_noise_like(resized_latents, override_perlin=True)
|
||||
|
||||
# Clear symmetry for the second pass
|
||||
from dataclasses import replace
|
||||
new_postprocessing_settings = replace(conditioning_data.postprocessing_settings, h_symmetry_time_pct=None)
|
||||
new_postprocessing_settings = replace(new_postprocessing_settings, v_symmetry_time_pct=None)
|
||||
new_conditioning_data = replace(conditioning_data, postprocessing_settings=new_postprocessing_settings)
|
||||
|
||||
verbosity = get_verbosity()
|
||||
set_verbosity_error()
|
||||
pipeline_output = pipeline.img2img_from_latents_and_embeddings(
|
||||
resized_latents,
|
||||
num_inference_steps=steps,
|
||||
conditioning_data=new_conditioning_data,
|
||||
strength=strength,
|
||||
noise=second_pass_noise,
|
||||
callback=step_callback)
|
||||
set_verbosity(verbosity)
|
||||
|
||||
if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
|
||||
attention_maps_callback(pipeline_output.attention_map_saver)
|
||||
|
||||
return pipeline.numpy_to_pil(pipeline_output.images)[0]
|
||||
|
||||
|
||||
# FIXME: do we really need something entirely different for the inpainting model?
|
||||
|
||||
# in the case of the inpainting model being loaded, the trick of
|
||||
# providing an interpolated latent doesn't work, so we transiently
|
||||
# create a 512x512 PIL image, upscale it, and run the inpainting
|
||||
# over it in img2img mode. Because the inpainting model is so conservative
# it doesn't change the image (much)
|
||||
|
||||
return make_image
|
||||
|
||||
def get_noise_like(self, like: torch.Tensor, override_perlin: bool=False):
|
||||
device = like.device
|
||||
if device.type == 'mps':
|
||||
x = torch.randn_like(like, device='cpu', dtype=self.torch_dtype()).to(device)
|
||||
else:
|
||||
x = torch.randn_like(like, device=device, dtype=self.torch_dtype())
|
||||
if self.perlin > 0.0 and not override_perlin:
|
||||
shape = like.shape
|
||||
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
|
||||
return x
|
||||
|
||||
# returns a tensor filled with random numbers from a normal distribution
|
||||
def get_noise(self,width,height,scale = True):
|
||||
# print(f"Get noise: {width}x{height}")
|
||||
if scale:
|
||||
# Scale the input width and height for the initial generation
|
||||
# Make their area equivalent to the model's resolution area (e.g. 512*512 = 262144),
|
||||
# while keeping the minimum dimension at least 0.5 * resolution (e.g. 512*0.5 = 256)
|
||||
|
||||
aspect = width / height
|
||||
dimension = self.model.unet.config.sample_size * self.model.vae_scale_factor
|
||||
min_dimension = math.floor(dimension * 0.5)
|
||||
model_area = dimension * dimension # hardcoded for now since all models are trained on square images
|
||||
|
||||
if aspect > 1.0:
|
||||
init_height = max(min_dimension, math.sqrt(model_area / aspect))
|
||||
init_width = init_height * aspect
|
||||
else:
|
||||
init_width = max(min_dimension, math.sqrt(model_area * aspect))
|
||||
init_height = init_width / aspect
|
||||
|
||||
scaled_width, scaled_height = trim_to_multiple_of(math.floor(init_width), math.floor(init_height))
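# Worked example (assuming a 512-px model and trimming to a multiple of 8): width=768, height=512
# gives aspect 1.5, so init_height = sqrt(262144 / 1.5) ~ 418 and init_width ~ 627, which
# trim_to_multiple_of would reduce to roughly 624x416 for the first pass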
|
||||
|
||||
else:
|
||||
scaled_width = width
|
||||
scaled_height = height
|
||||
|
||||
device = self.model.device
|
||||
channels = self.latent_channels
|
||||
if channels == 9:
|
||||
channels = 4 # we don't really want noise for all the mask channels
|
||||
shape = (1, channels,
|
||||
scaled_height // self.downsampling_factor, scaled_width // self.downsampling_factor)
|
||||
if self.use_mps_noise or device.type == 'mps':
|
||||
tensor = torch.empty(size=shape, device='cpu')
|
||||
tensor = self.get_noise_like(like=tensor).to(device)
|
||||
else:
|
||||
tensor = torch.empty(size=shape, device=device)
|
||||
tensor = self.get_noise_like(like=tensor)
|
||||
return tensor
|
Binary file not shown.