diffusers: txt2img2img (hires_fix)

with so much slicing and dicing of pipeline methods to stitch them together
This commit is contained in:
Kevin Turner 2022-12-06 19:16:28 -08:00
parent bf6376417a
commit 04a5bc938e
3 changed files with 149 additions and 152 deletions
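In rough terms, the new txt2img2img path stitches the pipeline methods together like this (a sketch assembled from the diffs below; x_T, c, uc, steps, cfg_scale, strength, extra_conditioning_info and noise_func stand in for values the Generator already holds, and the literal 8 is the SD latent downsampling factor):

# 1) first pass: ordinary txt2img, but at a reduced size near 512px on the short side
first_pass = pipeline.latents_from_embeddings(
    latents=x_T, num_inference_steps=steps,
    text_embeddings=c, unconditioned_embeddings=uc,
    guidance_scale=cfg_scale)

# 2) upscale the *latents* (not pixels) to the requested resolution
resized_latents = torch.nn.functional.interpolate(
    first_pass.latents,
    size=(height // 8, width // 8), mode="bilinear")

# 3) second pass: img2img over the upscaled latents, gated by `strength`
final = pipeline.img2img_from_latents_and_embeddings(
    resized_latents, num_inference_steps=steps,
    text_embeddings=c, unconditioned_embeddings=uc,
    guidance_scale=cfg_scale, strength=strength,
    extra_conditioning_info=extra_conditioning_info,
    noise_func=noise_func)
image = pipeline.numpy_to_pil(final.images)[0]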

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import secrets
import warnings
from dataclasses import dataclass
from typing import List, Optional, Union, Callable
from typing import List, Optional, Union, Callable, Type, TypeVar, Generic, Any, ParamSpec
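# (note: typing.ParamSpec needs Python >= 3.10; on older interpreters it would have to come from typing_extensions)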
import PIL.Image
import einops
@ -11,11 +11,11 @@ import torch
import torchvision.transforms as T
from diffusers.models import attention
from ldm.models.diffusion.cross_attention_control import InvokeAIDiffusersCrossAttention
from ...models.diffusion import cross_attention_control
# monkeypatch diffusers CrossAttention 🙈
# this is to make prompt2prompt and (future) attention maps work
attention.CrossAttention = InvokeAIDiffusersCrossAttention
attention.CrossAttention = cross_attention_control.InvokeAIDiffusersCrossAttention
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
@ -126,6 +126,10 @@ class AddsMaskGuidance:
return masked_input
def trim_to_multiple_of(*args, multiple_of=8):
return tuple((x - x % multiple_of) for x in args)
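# e.g. trim_to_multiple_of(513, 770) -> (512, 768); values already on the 8-pixel grid pass through unchanged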
def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool=True, multiple_of=8) -> torch.FloatTensor:
"""
@ -133,8 +137,7 @@ def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool=True
:param normalize: scale the range to [-1, 1] instead of [0, 1]
:param multiple_of: resize the input so both dimensions are a multiple of this
"""
w, h = image.size
w, h = map(lambda x: x - x % multiple_of, (w, h)) # resize to integer multiple of 8
w, h = trim_to_multiple_of(*image.size)
transformation = T.Compose([
T.Resize((h, w), T.InterpolationMode.LANCZOS),
T.ToTensor(),
@ -148,6 +151,26 @@ def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool=True
def is_inpainting_model(unet: UNet2DConditionModel):
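# inpainting UNets (e.g. runwayml/stable-diffusion-inpainting) take 9 input channels:
# 4 noisy latents + 4 masked-image latents + 1 mask, vs. 4 for the standard text-to-image UNet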
return unet.conv_in.in_channels == 9
CallbackType = TypeVar('CallbackType')
ReturnType = TypeVar('ReturnType')
ParamType = ParamSpec('ParamType')
@dataclass(frozen=True)
class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
generator_method: Callable[ParamType, ReturnType]
callback_arg_type: Type[CallbackType]
def __call__(self, *args: ParamType.args,
callback:Callable[[CallbackType], Any]=None,
**kwargs: ParamType.kwargs) -> ReturnType:
result = None
for result in self.generator_method(*args, **kwargs):
if callback is not None and isinstance(result, self.callback_arg_type):
callback(result)
if result is None:
raise AssertionError("why was that an empty generator?")
return result
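To make the helper's control flow concrete, a toy usage outside the pipeline (illustrative only; Tick and count_to are made-up names):

from dataclasses import dataclass

@dataclass
class Tick:
    step: int

def count_to(n: int):
    for i in range(n):
        yield Tick(step=i)    # every yielded Tick is forwarded to the callback

wrapped = GeneratorToCallbackinator(count_to, Tick)
last = wrapped(3, callback=lambda tick: print("saw step", tick.step))
assert last == Tick(step=2)   # the wrapper returns the last value the generator yielded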
class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
r"""
@ -250,6 +273,21 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
raise AssertionError("why was that an empty generator?")
return result
def latents_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
text_embeddings: torch.Tensor, unconditioned_embeddings: torch.Tensor,
guidance_scale: float,
*, callback: Callable[[PipelineIntermediateState], None]=None,
extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo=None,
run_id=None,
**extra_step_kwargs) -> PipelineIntermediateState:
self.scheduler.set_timesteps(num_inference_steps, device=self.unet.device)
f = GeneratorToCallbackinator(self.generate_latents_from_embeddings, PipelineIntermediateState)
return f(latents, text_embeddings, unconditioned_embeddings, guidance_scale,
extra_conditioning_info=extra_conditioning_info,
run_id=run_id,
callback=callback,
**extra_step_kwargs)
def generate(
self,
prompt: Union[str, List[str]],
@ -303,19 +341,42 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
timesteps = None,
additional_guidance: List[Callable] = None,
**extra_step_kwargs):
latents = yield from self.generate_latents_from_embeddings(latents, text_embeddings, unconditioned_embeddings,
guidance_scale, run_id=run_id, extra_conditioning_info=extra_conditioning_info,
timesteps=timesteps, additional_guidance=additional_guidance, **extra_step_kwargs)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
with torch.inference_mode():
image = self.decode_latents(latents)
output = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=[])
yield self.check_for_safety(output, dtype=text_embeddings.dtype)
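The `latents = yield from ...` line above leans on a Python detail worth spelling out: `yield from` re-yields everything the sub-generator yields (here, the PipelineIntermediateState updates) and evaluates to the sub-generator's return value (here, the final latents). A minimal illustration:

def inner():
    yield "progress"
    return "final value"          # becomes the value of the `yield from` expression

def outer():
    final = yield from inner()    # re-yields "progress", then binds "final value"
    yield f"outer got: {final}"

assert list(outer()) == ["progress", "outer got: final value"]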
def generate_latents_from_embeddings(
self,
latents: torch.Tensor,
text_embeddings: torch.Tensor,
unconditioned_embeddings: torch.Tensor,
guidance_scale: float,
*,
run_id: str = None,
extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None,
timesteps = None,
additional_guidance: List[Callable] = None,
**extra_step_kwargs
):
if run_id is None:
run_id = secrets.token_urlsafe(self.ID_LENGTH)
if additional_guidance is None:
additional_guidance = []
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info,
step_count=len(self.scheduler.timesteps))
else:
self.invokeai_diffuser.remove_cross_attention_control()
if timesteps is None:
# NOTE: Depends on scheduler being already initialized!
timesteps = self.scheduler.timesteps
# scale the initial noise by the standard deviation required by the scheduler
@ -326,7 +387,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
batch_size = latents.shape[0]
batched_t = torch.full((batch_size,), timesteps[0],
dtype=timesteps.dtype, device=self.unet.device)
# NOTE: Depends on scheduler being already initialized!
for i, t in enumerate(self.progress_bar(timesteps)):
batched_t.fill_(t)
step_output = self.step(batched_t, latents, guidance_scale,
@ -337,14 +398,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
predicted_original = getattr(step_output, 'pred_original_sample', None)
yield PipelineIntermediateState(run_id=run_id, step=i, timestep=int(t), latents=latents,
predicted_original=predicted_original)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
with torch.inference_mode():
image = self.decode_latents(latents)
output = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=[])
yield self.check_for_safety(output, dtype=text_embeddings.dtype)
return latents
@torch.inference_mode()
def step(self, t: torch.Tensor, latents: torch.Tensor, guidance_scale: float,
@ -396,34 +450,38 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
**extra_step_kwargs) -> StableDiffusionPipelineOutput:
device = self.unet.device
latents_dtype = self.unet.dtype
batch_size = 1
num_images_per_prompt = 1
if isinstance(init_image, PIL.Image.Image):
init_image = image_resized_to_grid_as_tensor(init_image.convert('RGB'))
if init_image.dim() == 3:
init_image = einops.rearrange(init_image, 'c h w -> 1 c h w')
# 6. Prepare latent variables
initial_latents = self.non_noised_latents_from_image(init_image, device=device, dtype=latents_dtype)
result = self.img2img_from_latents_and_embeddings(initial_latents, num_inference_steps, text_embeddings,
unconditioned_embeddings, guidance_scale, strength,
extra_conditioning_info, noise_func, run_id, callback,
**extra_step_kwargs)
return result
def img2img_from_latents_and_embeddings(self, initial_latents, num_inference_steps, text_embeddings,
unconditioned_embeddings, guidance_scale, strength, extra_conditioning_info,
noise_func, run_id=None, callback=None, **extra_step_kwargs):
device = self.unet.device
batch_size = initial_latents.size(0)
img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = img2img_pipeline.get_timesteps(num_inference_steps, strength, device=device)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
# 6. Prepare latent variables
latents, _ = self.prepare_latents_from_image(init_image, latent_timestep, latents_dtype, device, noise_func)
latent_timestep = timesteps[:1].repeat(batch_size)
latents = self.noise_latents_for_time(initial_latents, latent_timestep, noise_func=noise_func)
result = None
for result in self.generate_from_embeddings(
latents, text_embeddings, unconditioned_embeddings, guidance_scale,
extra_conditioning_info=extra_conditioning_info,
timesteps=timesteps,
run_id=run_id, **extra_step_kwargs):
if callback is not None and isinstance(result, PipelineIntermediateState):
callback(result)
if result is None:
raise AssertionError("why was that an empty generator?")
return result
f = GeneratorToCallbackinator(self.generate_from_embeddings, PipelineIntermediateState)
return f(latents, text_embeddings, unconditioned_embeddings, guidance_scale,
extra_conditioning_info=extra_conditioning_info,
timesteps=timesteps,
callback=callback,
run_id=run_id, **extra_step_kwargs)
def inpaint_from_embeddings(
self,
@ -459,7 +517,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# 6. Prepare latent variables
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
latents, init_image_latents = self.prepare_latents_from_image(init_image, latent_timestep, latents_dtype, device, noise_func)
# can't quite use upstream StableDiffusionImg2ImgPipeline.prepare_latents
# because we have our own noise function
init_image_latents = self.non_noised_latents_from_image(init_image, device=device, dtype=latents_dtype)
latents = self.noise_latents_for_time(init_image_latents, latent_timestep, noise_func=noise_func)
if mask.dim() == 3:
mask = mask.unsqueeze(0)
@ -491,19 +552,18 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
finally:
self.invokeai_diffuser.model_forward_callback = self._unet_forward
def prepare_latents_from_image(self, init_image, timestep, dtype, device, noise_func) -> (torch.FloatTensor, torch.FloatTensor):
# can't quite use upstream StableDiffusionImg2ImgPipeline.prepare_latents
# because we have our own noise function
def non_noised_latents_from_image(self, init_image, *, device, dtype):
init_image = init_image.to(device=device, dtype=dtype)
with torch.inference_mode():
init_latent_dist = self.vae.encode(init_image).latent_dist
init_latents = init_latent_dist.sample().to(dtype=dtype) # FIXME: uses torch.randn. make reproducible!
init_latents = 0.18215 * init_latents
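# 0.18215 is the standard Stable Diffusion VAE latent scaling factor (the reciprocal is applied when decoding)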
return init_latents
noise = noise_func(init_latents)
noised_latents = self.scheduler.add_noise(init_latents, noise, timestep)
return noised_latents, init_latents
def noise_latents_for_time(self, latents, timestep, *, noise_func):
noise = noise_func(latents)
noised_latents = self.scheduler.add_noise(latents, noise, timestep)
return noised_latents
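For intuition about what the scheduler's add_noise contributes here, a small standalone check (a sketch assuming a DDPM-family scheduler; the exact scaling is scheduler-specific):

import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)
latents = torch.zeros(1, 4, 64, 64)      # stand-in for image latents
noise = torch.randn_like(latents)
early_t = torch.tensor([10])             # low timestep: mostly original latents survive
late_t = torch.tensor([990])             # high timestep: mostly noise
a = scheduler.add_noise(latents, noise, early_t)
b = scheduler.add_noise(latents, noise, late_t)
print(a.std().item(), b.std().item())    # small vs. ~1.0 -- img2img strength picks where on this scale to start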
def check_for_safety(self, output, dtype):
with torch.inference_mode():

View File

@ -29,10 +29,6 @@ class Txt2Img(Generator):
pipeline.scheduler = sampler
def make_image(x_T) -> PIL.Image.Image:
# FIXME: restore free_gpu_mem functionality
# if self.free_gpu_mem and self.model.model.device != self.model.device:
# self.model.model.to(self.model.device)
pipeline_output = pipeline.image_from_embeddings(
latents=x_T,
num_inference_steps=steps,
@ -45,10 +41,6 @@ class Txt2Img(Generator):
# TODO: threshold = threshold,
)
# FIXME: restore free_gpu_mem functionality
# if self.free_gpu_mem:
# self.model.model.to("cpu")
return pipeline.numpy_to_pil(pipeline_output.images)[0]
return make_image

View File

@ -3,13 +3,12 @@ ldm.invoke.generator.txt2img inherits from ldm.invoke.generator
'''
import math
from typing import Callable, Optional
import torch
from PIL import Image
from ldm.invoke.generator.base import Generator
from ldm.invoke.generator.omnibus import Omnibus
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.invoke.generator.diffusers_pipeline import trim_to_multiple_of, StableDiffusionGeneratorPipeline
class Txt2Img2Img(Generator):
@ -17,9 +16,9 @@ class Txt2Img2Img(Generator):
super().__init__(model, precision)
self.init_latent = None # for get_noise()
@torch.no_grad()
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
conditioning,width,height,strength,step_callback=None,**kwargs):
def get_make_image(self, prompt:str, sampler, steps:int, cfg_scale:float, ddim_eta,
conditioning, width:int, height:int, strength:float,
step_callback:Optional[Callable]=None, **kwargs):
"""
Returns a function returning an image derived from the prompt and the initial image
Return value depends on the seed at the time you call it
@ -29,125 +28,72 @@ class Txt2Img2Img(Generator):
scale_dim = min(width, height)
scale = 512 / scale_dim
init_width = math.ceil(scale * width / 64) * 64
init_height = math.ceil(scale * height / 64) * 64
init_width, init_height = trim_to_multiple_of(scale * width, scale * height)
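# worked example (illustrative numbers): width=1024, height=768
#   scale_dim = 768, scale = 512/768 ≈ 0.667
#   trim_to_multiple_of(0.667*1024, 0.667*768) ≈ trim_to_multiple_of(682.7, 512.0) -> (680.0, 512.0)
# i.e. the first pass renders with the short side near 512; its latents are then
# upscaled to the full width x height for the img2img pass below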
# noinspection PyTypeChecker
pipeline: StableDiffusionGeneratorPipeline = self.model
pipeline.scheduler = sampler
@torch.no_grad()
def make_image(x_T):
shape = [
self.latent_channels,
init_height // self.downsampling_factor,
init_width // self.downsampling_factor,
]
sampler.make_schedule(
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
pipeline_output = pipeline.latents_from_embeddings(
latents=x_T,
num_inference_steps=steps,
text_embeddings=c,
unconditioned_embeddings=uc,
guidance_scale=cfg_scale,
callback=step_callback,
extra_conditioning_info=extra_conditioning_info,
# TODO: eta = ddim_eta,
# TODO: threshold = threshold,
)
#x = self.get_noise(init_width, init_height)
x = x_T
if self.free_gpu_mem and self.model.model.device != self.model.device:
self.model.model.to(self.model.device)
samples, _ = sampler.sample(
batch_size = 1,
S = steps,
x_T = x,
conditioning = c,
shape = shape,
verbose = False,
unconditional_guidance_scale = cfg_scale,
unconditional_conditioning = uc,
eta = ddim_eta,
img_callback = step_callback,
extra_conditioning_info = extra_conditioning_info
)
first_pass_latent_output = pipeline_output.latents
print(
f"\n>> Interpolating from {init_width}x{init_height} to {width}x{height} using DDIM sampling"
)
# resizing
samples = torch.nn.functional.interpolate(
samples,
resized_latents = torch.nn.functional.interpolate(
first_pass_latent_output,
size=(height // self.downsampling_factor, width // self.downsampling_factor),
mode="bilinear"
)
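# latents are (batch, 4, height//8, width//8); interpolating them bilinearly, rather than
# decoding to pixels and re-encoding, is what makes this a latent-space hires fix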
t_enc = int(strength * steps)
ddim_sampler = DDIMSampler(self.model, device=self.model.device)
ddim_sampler.make_schedule(
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
)
z_enc = ddim_sampler.stochastic_encode(
samples,
torch.tensor([t_enc]).to(self.model.device),
noise=self.get_noise(width,height,False)
)
# decode it
samples = ddim_sampler.decode(
z_enc,
c,
t_enc,
img_callback = step_callback,
unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,
extra_conditioning_info=extra_conditioning_info,
all_timesteps_count=steps
)
pipeline_output = pipeline.img2img_from_latents_and_embeddings(
resized_latents,
num_inference_steps=steps,
text_embeddings=c,
unconditioned_embeddings=uc,
guidance_scale=cfg_scale, strength=strength,
extra_conditioning_info=extra_conditioning_info,
noise_func=self.get_noise_like,
callback=step_callback)
if self.free_gpu_mem:
self.model.model.to("cpu")
return pipeline.numpy_to_pil(pipeline_output.images)[0]
return self.sample_to_image(samples)
# FIXME: do we really need something entirely different for the inpainting model?
# in the case of the inpainting model being loaded, the trick of
# providing an interpolated latent doesn't work, so we transiently
# create a 512x512 PIL image, upscale it, and run the inpainting
# over it in img2img mode. Because the inpainting model is so conservative
# it doesn't change the image (much)
def inpaint_make_image(x_T):
omnibus = Omnibus(self.model,self.precision)
result = omnibus.generate(
prompt,
sampler=sampler,
width=init_width,
height=init_height,
step_callback=step_callback,
steps = steps,
cfg_scale = cfg_scale,
ddim_eta = ddim_eta,
conditioning = conditioning,
**kwargs
)
assert result is not None and len(result)>0,'** txt2img failed **'
image = result[0][0]
interpolated_image = image.resize((width,height),resample=Image.Resampling.LANCZOS)
print(kwargs.pop('init_image',None))
result = omnibus.generate(
prompt,
sampler=sampler,
init_image=interpolated_image,
width=width,
height=height,
seed=result[0][1],
step_callback=step_callback,
steps = steps,
cfg_scale = cfg_scale,
ddim_eta = ddim_eta,
conditioning = conditioning,
**kwargs
)
return result[0][0]
if sampler.uses_inpainting_model():
return inpaint_make_image
else:
return make_image
return make_image
def get_noise_like(self, like: torch.Tensor):
device = like.device
if device.type == 'mps':
x = torch.randn_like(like, device='cpu').to(device)
else:
x = torch.randn_like(like, device=device)
if self.perlin > 0.0:
shape = like.shape
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
return x
# returns a tensor filled with random numbers from a normal distribution
def get_noise(self,width,height,scale = True):
@ -175,4 +121,3 @@ class Txt2Img2Img(Generator):
scaled_height // self.downsampling_factor,
scaled_width // self.downsampling_factor],
device=device)