diffusers: txt2img2img (hires_fix)

with so much slicing and dicing of pipeline methods to stitch them together
Kevin Turner
2022-12-06 19:16:28 -08:00
parent bf6376417a
commit 04a5bc938e
3 changed files with 149 additions and 152 deletions
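The hires-fix flow this commit stitches together is: run the text-to-image denoising loop at a reduced size, bilinearly interpolate the resulting latents up to the target size, then run an img2img pass over them. Below is a minimal sketch of that flow using the pipeline methods added in this commit; it assumes an already-loaded StableDiffusionGeneratorPipeline plus pre-computed conditioned/unconditioned embeddings (`c`, `uc`), and the helper name `hires_fix_sketch` and its default parameters are illustrative only, not part of the commit.

    import torch

    def hires_fix_sketch(pipeline, x_T, c, uc, *, steps=50, cfg_scale=7.5,
                         strength=0.7, target_height=768, target_width=768):
        # First pass: denoise the initial noise x_T at the reduced size.
        first_pass = pipeline.latents_from_embeddings(
            latents=x_T, num_inference_steps=steps,
            text_embeddings=c, unconditioned_embeddings=uc,
            guidance_scale=cfg_scale)

        # Upscale the latents to the target resolution (latent space is 8x smaller).
        resized_latents = torch.nn.functional.interpolate(
            first_pass.latents,
            size=(target_height // 8, target_width // 8),
            mode="bilinear")

        # Second pass: img2img over the interpolated latents.
        output = pipeline.img2img_from_latents_and_embeddings(
            resized_latents, num_inference_steps=steps,
            text_embeddings=c, unconditioned_embeddings=uc,
            guidance_scale=cfg_scale, strength=strength,
            extra_conditioning_info=None,
            noise_func=torch.randn_like)
        return pipeline.numpy_to_pil(output.images)[0]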

ldm/invoke/generator/diffusers_pipeline.py

@@ -3,7 +3,7 @@ from __future__ import annotations
 import secrets
 import warnings
 from dataclasses import dataclass
-from typing import List, Optional, Union, Callable
+from typing import List, Optional, Union, Callable, Type, TypeVar, Generic, Any, ParamSpec
 
 import PIL.Image
 import einops
@@ -11,11 +11,11 @@ import torch
 import torchvision.transforms as T
 from diffusers.models import attention
 
-from ldm.models.diffusion.cross_attention_control import InvokeAIDiffusersCrossAttention
+from ...models.diffusion import cross_attention_control
 
 # monkeypatch diffusers CrossAttention 🙈
 # this is to make prompt2prompt and (future) attention maps work
-attention.CrossAttention = InvokeAIDiffusersCrossAttention
+attention.CrossAttention = cross_attention_control.InvokeAIDiffusersCrossAttention
 
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
@@ -126,6 +126,10 @@ class AddsMaskGuidance:
         return masked_input
 
 
+def trim_to_multiple_of(*args, multiple_of=8):
+    return tuple((x - x % multiple_of) for x in args)
+
+
 def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool=True, multiple_of=8) -> torch.FloatTensor:
     """
@@ -133,8 +137,7 @@ def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool=True
     :param normalize: scale the range to [-1, 1] instead of [0, 1]
     :param multiple_of: resize the input so both dimensions are a multiple of this
     """
-    w, h = image.size
-    w, h = map(lambda x: x - x % multiple_of, (w, h))  # resize to integer multiple of 8
+    w, h = trim_to_multiple_of(*image.size)
     transformation = T.Compose([
         T.Resize((h, w), T.InterpolationMode.LANCZOS),
         T.ToTensor(),
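The new helper simply floors each dimension down to the nearest multiple (8 by default). A quick self-contained illustration, with the function copied from the hunk above and arbitrary example values:

    def trim_to_multiple_of(*args, multiple_of=8):
        return tuple((x - x % multiple_of) for x in args)

    print(trim_to_multiple_of(517, 389))                    # (512, 384)
    print(trim_to_multiple_of(1023, 767, multiple_of=64))   # (960, 704)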
@@ -148,6 +151,26 @@ def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool=True
 def is_inpainting_model(unet: UNet2DConditionModel):
     return unet.conv_in.in_channels == 9
 
+
+CallbackType = TypeVar('CallbackType')
+ReturnType = TypeVar('ReturnType')
+ParamType = ParamSpec('ParamType')
+
+@dataclass(frozen=True)
+class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
+    generator_method: Callable[ParamType, ReturnType]
+    callback_arg_type: Type[CallbackType]
+
+    def __call__(self, *args: ParamType.args,
+                 callback: Callable[[CallbackType], Any]=None,
+                 **kwargs: ParamType.kwargs) -> ReturnType:
+        result = None
+        for result in self.generator_method(*args, **kwargs):
+            if callback is not None and isinstance(result, self.callback_arg_type):
+                callback(result)
+        if result is None:
+            raise AssertionError("why was that an empty generator?")
+        return result
+
 
 class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
     r"""
@@ -250,6 +273,21 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             raise AssertionError("why was that an empty generator?")
         return result
 
+    def latents_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
+                                text_embeddings: torch.Tensor, unconditioned_embeddings: torch.Tensor,
+                                guidance_scale: float,
+                                *, callback: Callable[[PipelineIntermediateState], None]=None,
+                                extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo=None,
+                                run_id=None,
+                                **extra_step_kwargs) -> PipelineIntermediateState:
+        self.scheduler.set_timesteps(num_inference_steps, device=self.unet.device)
+        f = GeneratorToCallbackinator(self.generate_latents_from_embeddings, PipelineIntermediateState)
+        return f(latents, text_embeddings, unconditioned_embeddings, guidance_scale,
+                 extra_conditioning_info=extra_conditioning_info,
+                 run_id=run_id,
+                 callback=callback,
+                 **extra_step_kwargs)
+
     def generate(
         self,
         prompt: Union[str, List[str]],
@@ -303,19 +341,42 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             timesteps = None,
             additional_guidance: List[Callable] = None,
             **extra_step_kwargs):
+        latents = yield from self.generate_latents_from_embeddings(latents, text_embeddings, unconditioned_embeddings,
+            guidance_scale, run_id=run_id, extra_conditioning_info=extra_conditioning_info,
+            timesteps=timesteps, additional_guidance=additional_guidance, **extra_step_kwargs)
+
+        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+        torch.cuda.empty_cache()
+
+        with torch.inference_mode():
+            image = self.decode_latents(latents)
+            output = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=[])
+            yield self.check_for_safety(output, dtype=text_embeddings.dtype)
+
+    def generate_latents_from_embeddings(
+            self,
+            latents: torch.Tensor,
+            text_embeddings: torch.Tensor,
+            unconditioned_embeddings: torch.Tensor,
+            guidance_scale: float,
+            *,
+            run_id: str = None,
+            extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None,
+            timesteps = None,
+            additional_guidance: List[Callable] = None,
+            **extra_step_kwargs
+    ):
         if run_id is None:
             run_id = secrets.token_urlsafe(self.ID_LENGTH)
         if additional_guidance is None:
             additional_guidance = []
         if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
             self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info,
                                                                  step_count=len(self.scheduler.timesteps))
         else:
             self.invokeai_diffuser.remove_cross_attention_control()
 
         if timesteps is None:
+            # NOTE: Depends on scheduler being already initialized!
             timesteps = self.scheduler.timesteps
 
         # scale the initial noise by the standard deviation required by the scheduler
@@ -326,7 +387,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         batch_size = latents.shape[0]
         batched_t = torch.full((batch_size,), timesteps[0],
                                dtype=timesteps.dtype, device=self.unet.device)
-        # NOTE: Depends on scheduler being already initialized!
+
         for i, t in enumerate(self.progress_bar(timesteps)):
             batched_t.fill_(t)
             step_output = self.step(batched_t, latents, guidance_scale,
@@ -337,14 +398,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             predicted_original = getattr(step_output, 'pred_original_sample', None)
             yield PipelineIntermediateState(run_id=run_id, step=i, timestep=int(t), latents=latents,
                                             predicted_original=predicted_original)
-
-        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
-        torch.cuda.empty_cache()
-
-        with torch.inference_mode():
-            image = self.decode_latents(latents)
-            output = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=[])
-            yield self.check_for_safety(output, dtype=text_embeddings.dtype)
+        return latents
 
     @torch.inference_mode()
     def step(self, t: torch.Tensor, latents: torch.Tensor, guidance_scale: float,
@@ -396,34 +450,38 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                              **extra_step_kwargs) -> StableDiffusionPipelineOutput:
         device = self.unet.device
         latents_dtype = self.unet.dtype
-        batch_size = 1
-        num_images_per_prompt = 1
 
         if isinstance(init_image, PIL.Image.Image):
             init_image = image_resized_to_grid_as_tensor(init_image.convert('RGB'))
 
         if init_image.dim() == 3:
             init_image = einops.rearrange(init_image, 'c h w -> 1 c h w')
 
+        # 6. Prepare latent variables
+        initial_latents = self.non_noised_latents_from_image(init_image, device=device, dtype=latents_dtype)
+
+        result = self.img2img_from_latents_and_embeddings(initial_latents, num_inference_steps, text_embeddings,
+                                                          unconditioned_embeddings, guidance_scale, strength,
+                                                          extra_conditioning_info, noise_func, run_id, callback,
+                                                          **extra_step_kwargs)
+        return result
+
+    def img2img_from_latents_and_embeddings(self, initial_latents, num_inference_steps, text_embeddings,
+                                            unconditioned_embeddings, guidance_scale, strength, extra_conditioning_info,
+                                            noise_func, run_id=None, callback=None, **extra_step_kwargs):
+        device = self.unet.device
+        batch_size = initial_latents.size(0)
         img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
         img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps = img2img_pipeline.get_timesteps(num_inference_steps, strength, device=device)
-        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+        latent_timestep = timesteps[:1].repeat(batch_size)
+        latents = self.noise_latents_for_time(initial_latents, latent_timestep, noise_func=noise_func)
 
-        # 6. Prepare latent variables
-        latents, _ = self.prepare_latents_from_image(init_image, latent_timestep, latents_dtype, device, noise_func)
-
-        result = None
-        for result in self.generate_from_embeddings(
-                latents, text_embeddings, unconditioned_embeddings, guidance_scale,
-                extra_conditioning_info=extra_conditioning_info,
-                timesteps=timesteps,
-                run_id=run_id, **extra_step_kwargs):
-            if callback is not None and isinstance(result, PipelineIntermediateState):
-                callback(result)
-        if result is None:
-            raise AssertionError("why was that an empty generator?")
-        return result
+        f = GeneratorToCallbackinator(self.generate_from_embeddings, PipelineIntermediateState)
+        return f(latents, text_embeddings, unconditioned_embeddings, guidance_scale,
+                 extra_conditioning_info=extra_conditioning_info,
+                 timesteps=timesteps,
+                 callback=callback,
+                 run_id=run_id, **extra_step_kwargs)
 
     def inpaint_from_embeddings(
         self,
@@ -459,7 +517,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         # 6. Prepare latent variables
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
-        latents, init_image_latents = self.prepare_latents_from_image(init_image, latent_timestep, latents_dtype, device, noise_func)
+        # can't quite use upstream StableDiffusionImg2ImgPipeline.prepare_latents
+        # because we have our own noise function
+        init_image_latents = self.non_noised_latents_from_image(init_image, device=device, dtype=latents_dtype)
+        latents = self.noise_latents_for_time(init_image_latents, latent_timestep, noise_func=noise_func)
 
         if mask.dim() == 3:
             mask = mask.unsqueeze(0)
@@ -491,19 +552,18 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         finally:
             self.invokeai_diffuser.model_forward_callback = self._unet_forward
 
-    def prepare_latents_from_image(self, init_image, timestep, dtype, device, noise_func) -> (torch.FloatTensor, torch.FloatTensor):
-        # can't quite use upstream StableDiffusionImg2ImgPipeline.prepare_latents
-        # because we have our own noise function
+    def non_noised_latents_from_image(self, init_image, *, device, dtype):
         init_image = init_image.to(device=device, dtype=dtype)
         with torch.inference_mode():
             init_latent_dist = self.vae.encode(init_image).latent_dist
             init_latents = init_latent_dist.sample().to(dtype=dtype)  # FIXME: uses torch.randn. make reproducible!
         init_latents = 0.18215 * init_latents
+        return init_latents
 
-        noise = noise_func(init_latents)
-        noised_latents = self.scheduler.add_noise(init_latents, noise, timestep)
-        return noised_latents, init_latents
+    def noise_latents_for_time(self, latents, timestep, *, noise_func):
+        noise = noise_func(latents)
+        noised_latents = self.scheduler.add_noise(latents, noise, timestep)
+        return noised_latents
 
     def check_for_safety(self, output, dtype):
         with torch.inference_mode():
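The refactor above splits latent preparation in two: non_noised_latents_from_image encodes the image through the VAE, and noise_latents_for_time adds scheduler noise for a given timestep via an injectable noise_func (the generators pass self.get_noise_like, see the txt2img2img diff below). A hedged sketch of what a conforming noise_func looks like; the seeded helper here is illustrative and not part of the commit.

    import torch

    def make_seeded_noise_func(seed: int):
        # noise_latents_for_time only requires: given a latents tensor, return
        # noise with the same shape, dtype and device.
        def noise_func(like: torch.Tensor) -> torch.Tensor:
            generator = torch.Generator(device='cpu').manual_seed(seed)
            noise = torch.randn(like.shape, generator=generator, dtype=like.dtype)
            return noise.to(like.device)
        return noise_func

    # Usage (assuming `pipeline`, `init_image` and `latent_timestep` exist as in
    # img2img_from_embeddings above):
    #   init_latents = pipeline.non_noised_latents_from_image(
    #       init_image, device=pipeline.unet.device, dtype=pipeline.unet.dtype)
    #   latents = pipeline.noise_latents_for_time(
    #       init_latents, latent_timestep, noise_func=make_seeded_noise_func(42))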

ldm/invoke/generator/txt2img.py

@@ -29,10 +29,6 @@ class Txt2Img(Generator):
         pipeline.scheduler = sampler
 
         def make_image(x_T) -> PIL.Image.Image:
-            # FIXME: restore free_gpu_mem functionality
-            # if self.free_gpu_mem and self.model.model.device != self.model.device:
-            #     self.model.model.to(self.model.device)
-
             pipeline_output = pipeline.image_from_embeddings(
                 latents=x_T,
                 num_inference_steps=steps,
@@ -45,10 +41,6 @@ class Txt2Img(Generator):
                 # TODO: threshold = threshold,
             )
 
-            # FIXME: restore free_gpu_mem functionality
-            # if self.free_gpu_mem:
-            #     self.model.model.to("cpu")
-
             return pipeline.numpy_to_pil(pipeline_output.images)[0]
 
         return make_image

ldm/invoke/generator/txt2img2img.py

@@ -3,13 +3,12 @@ ldm.invoke.generator.txt2img inherits from ldm.invoke.generator
 '''
 
 import math
+from typing import Callable, Optional
 
 import torch
-from PIL import Image
 
 from ldm.invoke.generator.base import Generator
-from ldm.invoke.generator.omnibus import Omnibus
-from ldm.models.diffusion.ddim import DDIMSampler
+from ldm.invoke.generator.diffusers_pipeline import trim_to_multiple_of, StableDiffusionGeneratorPipeline
 
 
 class Txt2Img2Img(Generator):
@@ -17,9 +16,9 @@ class Txt2Img2Img(Generator):
         super().__init__(model, precision)
         self.init_latent = None    # for get_noise()
 
-    @torch.no_grad()
-    def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
-                       conditioning,width,height,strength,step_callback=None,**kwargs):
+    def get_make_image(self, prompt:str, sampler, steps:int, cfg_scale:float, ddim_eta,
+                       conditioning, width:int, height:int, strength:float,
+                       step_callback:Optional[Callable]=None, **kwargs):
         """
         Returns a function returning an image derived from the prompt and the initial image
         Return value depends on the seed at the time you call it
@@ -29,125 +28,72 @@ class Txt2Img2Img(Generator):
         scale_dim = min(width, height)
         scale = 512 / scale_dim
 
-        init_width = math.ceil(scale * width / 64) * 64
-        init_height = math.ceil(scale * height / 64) * 64
+        init_width, init_height = trim_to_multiple_of(scale * width, scale * height)
+
+        # noinspection PyTypeChecker
+        pipeline: StableDiffusionGeneratorPipeline = self.model
+        pipeline.scheduler = sampler
 
+        @torch.no_grad()
         def make_image(x_T):
-            shape = [
-                self.latent_channels,
-                init_height // self.downsampling_factor,
-                init_width // self.downsampling_factor,
-            ]
-
-            sampler.make_schedule(
-                ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
-            )
-
-            #x = self.get_noise(init_width, init_height)
-            x = x_T
-
-            if self.free_gpu_mem and self.model.model.device != self.model.device:
-                self.model.model.to(self.model.device)
-
-            samples, _ = sampler.sample(
-                batch_size = 1,
-                S = steps,
-                x_T = x,
-                conditioning = c,
-                shape = shape,
-                verbose = False,
-                unconditional_guidance_scale = cfg_scale,
-                unconditional_conditioning = uc,
-                eta = ddim_eta,
-                img_callback = step_callback,
-                extra_conditioning_info = extra_conditioning_info
-            )
+            pipeline_output = pipeline.latents_from_embeddings(
+                latents=x_T,
+                num_inference_steps=steps,
+                text_embeddings=c,
+                unconditioned_embeddings=uc,
+                guidance_scale=cfg_scale,
+                callback=step_callback,
+                extra_conditioning_info=extra_conditioning_info,
+                # TODO: eta = ddim_eta,
+                # TODO: threshold = threshold,
+            )
+            first_pass_latent_output = pipeline_output.latents
 
             print(
                 f"\n>> Interpolating from {init_width}x{init_height} to {width}x{height} using DDIM sampling"
             )
 
             # resizing
-            samples = torch.nn.functional.interpolate(
-                samples,
+            resized_latents = torch.nn.functional.interpolate(
+                first_pass_latent_output,
                 size=(height // self.downsampling_factor, width // self.downsampling_factor),
                 mode="bilinear"
            )
 
-            t_enc = int(strength * steps)
-            ddim_sampler = DDIMSampler(self.model, device=self.model.device)
-            ddim_sampler.make_schedule(
-                ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
-            )
-
-            z_enc = ddim_sampler.stochastic_encode(
-                samples,
-                torch.tensor([t_enc]).to(self.model.device),
-                noise=self.get_noise(width,height,False)
-            )
-
-            # decode it
-            samples = ddim_sampler.decode(
-                z_enc,
-                c,
-                t_enc,
-                img_callback = step_callback,
-                unconditional_guidance_scale=cfg_scale,
-                unconditional_conditioning=uc,
-                extra_conditioning_info=extra_conditioning_info,
-                all_timesteps_count=steps
-            )
-
-            if self.free_gpu_mem:
-                self.model.model.to("cpu")
-
-            return self.sample_to_image(samples)
+            pipeline_output = pipeline.img2img_from_latents_and_embeddings(
+                resized_latents,
+                num_inference_steps=steps,
+                text_embeddings=c,
+                unconditioned_embeddings=uc,
+                guidance_scale=cfg_scale, strength=strength,
+                extra_conditioning_info=extra_conditioning_info,
+                noise_func=self.get_noise_like,
+                callback=step_callback)
+            return pipeline.numpy_to_pil(pipeline_output.images)[0]
 
+        # FIXME: do we really need something entirely different for the inpainting model?
         # in the case of the inpainting model being loaded, the trick of
         # providing an interpolated latent doesn't work, so we transiently
         # create a 512x512 PIL image, upscale it, and run the inpainting
         # over it in img2img mode. Because the inpaing model is so conservative
         # it doesn't change the image (much)
-        def inpaint_make_image(x_T):
-            omnibus = Omnibus(self.model,self.precision)
-            result = omnibus.generate(
-                prompt,
-                sampler=sampler,
-                width=init_width,
-                height=init_height,
-                step_callback=step_callback,
-                steps = steps,
-                cfg_scale = cfg_scale,
-                ddim_eta = ddim_eta,
-                conditioning = conditioning,
-                **kwargs
-            )
-            assert result is not None and len(result)>0,'** txt2img failed **'
-            image = result[0][0]
-            interpolated_image = image.resize((width,height),resample=Image.Resampling.LANCZOS)
-            print(kwargs.pop('init_image',None))
-            result = omnibus.generate(
-                prompt,
-                sampler=sampler,
-                init_image=interpolated_image,
-                width=width,
-                height=height,
-                seed=result[0][1],
-                step_callback=step_callback,
-                steps = steps,
-                cfg_scale = cfg_scale,
-                ddim_eta = ddim_eta,
-                conditioning = conditioning,
-                **kwargs
-            )
-            return result[0][0]
 
-        if sampler.uses_inpainting_model():
-            return inpaint_make_image
-        else:
-            return make_image
+        return make_image
+
+    def get_noise_like(self, like: torch.Tensor):
+        device = like.device
+        if device.type == 'mps':
+            x = torch.randn_like(like, device='cpu').to(device)
+        else:
+            x = torch.randn_like(like, device=device)
+        if self.perlin > 0.0:
+            shape = like.shape
+            x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
+        return x
 
     # returns a tensor filled with random numbers from a normal distribution
     def get_noise(self,width,height,scale = True):
@@ -175,4 +121,3 @@ class Txt2Img2Img(Generator):
                             scaled_height // self.downsampling_factor,
                             scaled_width // self.downsampling_factor],
                             device=device)