Mirror of https://github.com/invoke-ai/InvokeAI

diffusers: txt2img2img (hires_fix)

with so much slicing and dicing of pipeline methods to stitch them together

commit 04a5bc938e (parent bf6376417a)
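The hires_fix flow stitched together below is, roughly: a first txt2img pass at a reduced init size, a bilinear upscale of the resulting latents, then a second img2img pass at the target size. The following is a minimal sketch of that flow using the pipeline methods introduced in this diff; it is not part of the commit, and it assumes pipeline, embeddings c/uc, steps, cfg_scale, strength, step_callback and downsampling_factor are set up as in the generators below.

    import torch

    def hires_fix_sketch(pipeline, x_T, c, uc, steps, cfg_scale, strength,
                         step_callback, downsampling_factor, width, height):
        # first pass: denoise the initial noise x_T at the reduced resolution
        first_pass = pipeline.latents_from_embeddings(
            latents=x_T, num_inference_steps=steps,
            text_embeddings=c, unconditioned_embeddings=uc,
            guidance_scale=cfg_scale, callback=step_callback)

        # upscale the first-pass latents to the target latent resolution
        resized_latents = torch.nn.functional.interpolate(
            first_pass.latents,
            size=(height // downsampling_factor, width // downsampling_factor),
            mode="bilinear")

        # second pass: img2img over the upscaled latents; the real code passes
        # self.get_noise_like as noise_func, a plain randn_like stands in here
        second_pass = pipeline.img2img_from_latents_and_embeddings(
            resized_latents, num_inference_steps=steps,
            text_embeddings=c, unconditioned_embeddings=uc,
            guidance_scale=cfg_scale, strength=strength,
            noise_func=torch.randn_like,
            callback=step_callback)
        return pipeline.numpy_to_pil(second_pass.images)[0]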
ldm/invoke/generator/diffusers_pipeline.py

@@ -3,7 +3,7 @@ from __future__ import annotations
 import secrets
 import warnings
 from dataclasses import dataclass
-from typing import List, Optional, Union, Callable
+from typing import List, Optional, Union, Callable, Type, TypeVar, Generic, Any, ParamSpec
 
 import PIL.Image
 import einops
@@ -11,11 +11,11 @@ import torch
 import torchvision.transforms as T
 from diffusers.models import attention
 
-from ldm.models.diffusion.cross_attention_control import InvokeAIDiffusersCrossAttention
+from ...models.diffusion import cross_attention_control
 
 # monkeypatch diffusers CrossAttention 🙈
 # this is to make prompt2prompt and (future) attention maps work
-attention.CrossAttention = InvokeAIDiffusersCrossAttention
+attention.CrossAttention = cross_attention_control.InvokeAIDiffusersCrossAttention
 
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
@@ -126,6 +126,10 @@ class AddsMaskGuidance:
         return masked_input
 
 
+def trim_to_multiple_of(*args, multiple_of=8):
+    return tuple((x - x % multiple_of) for x in args)
+
+
 def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool=True, multiple_of=8) -> torch.FloatTensor:
     """
 
@@ -133,8 +137,7 @@ def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool=True
     :param normalize: scale the range to [-1, 1] instead of [0, 1]
     :param multiple_of: resize the input so both dimensions are a multiple of this
     """
-    w, h = image.size
-    w, h = map(lambda x: x - x % multiple_of, (w, h))   # resize to integer multiple of 8
+    w, h = trim_to_multiple_of(*image.size)
     transformation = T.Compose([
         T.Resize((h, w), T.InterpolationMode.LANCZOS),
         T.ToTensor(),
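A worked example of the trim_to_multiple_of helper added above (illustrative only, not part of the commit): each dimension is rounded down to the nearest multiple of multiple_of, which defaults to 8.

    # illustrative usage of the helper added above
    assert trim_to_multiple_of(1023, 769) == (1016, 768)               # 1023 -> 1016, 769 -> 768
    assert trim_to_multiple_of(512, 640) == (512, 640)                 # already multiples of 8
    assert trim_to_multiple_of(100, 200, multiple_of=64) == (64, 192)  # custom granularity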
@@ -148,6 +151,26 @@ def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool=True
 def is_inpainting_model(unet: UNet2DConditionModel):
     return unet.conv_in.in_channels == 9
 
+CallbackType = TypeVar('CallbackType')
+ReturnType = TypeVar('ReturnType')
+ParamType = ParamSpec('ParamType')
+
+@dataclass(frozen=True)
+class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
+    generator_method: Callable[ParamType, ReturnType]
+    callback_arg_type: Type[CallbackType]
+
+    def __call__(self, *args: ParamType.args,
+                 callback:Callable[[CallbackType], Any]=None,
+                 **kwargs: ParamType.kwargs) -> ReturnType:
+        result = None
+        for result in self.generator_method(*args, **kwargs):
+            if callback is not None and isinstance(result, self.callback_arg_type):
+                callback(result)
+        if result is None:
+            raise AssertionError("why was that an empty generator?")
+        return result
+
 
 class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
     r"""
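A minimal usage sketch of the GeneratorToCallbackinator added above (illustrative only; Progress, fake_generator and on_progress are made-up names): it drains a generator method, routes intermediate values of the declared callback type to the callback, and returns the last value yielded. The pipeline methods below wire it up with generate_latents_from_embeddings and PipelineIntermediateState.

    from dataclasses import dataclass

    @dataclass
    class Progress:                # stand-in for PipelineIntermediateState
        step: int

    def fake_generator(n: int):
        for i in range(n):
            yield Progress(step=i)   # intermediate states, reported via callback
        yield "final result"         # the last yield is what the caller gets back

    def on_progress(state: Progress):
        print(f"step {state.step}")

    f = GeneratorToCallbackinator(fake_generator, Progress)
    result = f(3, callback=on_progress)   # prints steps 0..2, returns "final result"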
@@ -250,6 +273,21 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             raise AssertionError("why was that an empty generator?")
         return result
 
+    def latents_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
+                                text_embeddings: torch.Tensor, unconditioned_embeddings: torch.Tensor,
+                                guidance_scale: float,
+                                *, callback: Callable[[PipelineIntermediateState], None]=None,
+                                extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo=None,
+                                run_id=None,
+                                **extra_step_kwargs) -> PipelineIntermediateState:
+        self.scheduler.set_timesteps(num_inference_steps, device=self.unet.device)
+        f = GeneratorToCallbackinator(self.generate_latents_from_embeddings, PipelineIntermediateState)
+        return f(latents, text_embeddings, unconditioned_embeddings, guidance_scale,
+                 extra_conditioning_info=extra_conditioning_info,
+                 run_id=run_id,
+                 callback=callback,
+                 **extra_step_kwargs)
+
     def generate(
         self,
         prompt: Union[str, List[str]],
@@ -303,19 +341,42 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             timesteps = None,
             additional_guidance: List[Callable] = None,
             **extra_step_kwargs):
+        latents = yield from self.generate_latents_from_embeddings(latents, text_embeddings, unconditioned_embeddings,
+            guidance_scale, run_id=run_id, extra_conditioning_info=extra_conditioning_info,
+            timesteps=timesteps, additional_guidance=additional_guidance, **extra_step_kwargs)
+
+        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+        torch.cuda.empty_cache()
+
+        with torch.inference_mode():
+            image = self.decode_latents(latents)
+            output = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=[])
+            yield self.check_for_safety(output, dtype=text_embeddings.dtype)
+
+    def generate_latents_from_embeddings(
+            self,
+            latents: torch.Tensor,
+            text_embeddings: torch.Tensor,
+            unconditioned_embeddings: torch.Tensor,
+            guidance_scale: float,
+            *,
+            run_id: str = None,
+            extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None,
+            timesteps = None,
+            additional_guidance: List[Callable] = None,
+            **extra_step_kwargs
+    ):
         if run_id is None:
             run_id = secrets.token_urlsafe(self.ID_LENGTH)
 
         if additional_guidance is None:
             additional_guidance = []
 
         if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
             self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info,
                                                                  step_count=len(self.scheduler.timesteps))
         else:
             self.invokeai_diffuser.remove_cross_attention_control()
 
         if timesteps is None:
+            # NOTE: Depends on scheduler being already initialized!
             timesteps = self.scheduler.timesteps
 
         # scale the initial noise by the standard deviation required by the scheduler
@@ -326,7 +387,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         batch_size = latents.shape[0]
         batched_t = torch.full((batch_size,), timesteps[0],
                                dtype=timesteps.dtype, device=self.unet.device)
+        # NOTE: Depends on scheduler being already initialized!
 
         for i, t in enumerate(self.progress_bar(timesteps)):
             batched_t.fill_(t)
             step_output = self.step(batched_t, latents, guidance_scale,
@@ -337,14 +398,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             predicted_original = getattr(step_output, 'pred_original_sample', None)
             yield PipelineIntermediateState(run_id=run_id, step=i, timestep=int(t), latents=latents,
                                             predicted_original=predicted_original)
-
-        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
-        torch.cuda.empty_cache()
-
-        with torch.inference_mode():
-            image = self.decode_latents(latents)
-            output = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=[])
-            yield self.check_for_safety(output, dtype=text_embeddings.dtype)
+        return latents
 
     @torch.inference_mode()
     def step(self, t: torch.Tensor, latents: torch.Tensor, guidance_scale: float,
@@ -396,34 +450,38 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                               **extra_step_kwargs) -> StableDiffusionPipelineOutput:
         device = self.unet.device
         latents_dtype = self.unet.dtype
         batch_size = 1
         num_images_per_prompt = 1
 
         if isinstance(init_image, PIL.Image.Image):
             init_image = image_resized_to_grid_as_tensor(init_image.convert('RGB'))
 
         if init_image.dim() == 3:
             init_image = einops.rearrange(init_image, 'c h w -> 1 c h w')
 
+        # 6. Prepare latent variables
+        initial_latents = self.non_noised_latents_from_image(init_image, device=device, dtype=latents_dtype)
+
+        result = self.img2img_from_latents_and_embeddings(initial_latents, num_inference_steps, text_embeddings,
+                                                          unconditioned_embeddings, guidance_scale, strength,
+                                                          extra_conditioning_info, noise_func, run_id, callback,
+                                                          **extra_step_kwargs)
+        return result
+
+    def img2img_from_latents_and_embeddings(self, initial_latents, num_inference_steps, text_embeddings,
+                                            unconditioned_embeddings, guidance_scale, strength, extra_conditioning_info,
+                                            noise_func, run_id=None, callback=None, **extra_step_kwargs):
+        device = self.unet.device
+        batch_size = initial_latents.size(0)
         img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
         img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps = img2img_pipeline.get_timesteps(num_inference_steps, strength, device=device)
-        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+        latent_timestep = timesteps[:1].repeat(batch_size)
+        latents = self.noise_latents_for_time(initial_latents, latent_timestep, noise_func=noise_func)
 
-        # 6. Prepare latent variables
-        latents, _ = self.prepare_latents_from_image(init_image, latent_timestep, latents_dtype, device, noise_func)
-
-        result = None
-        for result in self.generate_from_embeddings(
-                latents, text_embeddings, unconditioned_embeddings, guidance_scale,
+        f = GeneratorToCallbackinator(self.generate_from_embeddings, PipelineIntermediateState)
+        return f(latents, text_embeddings, unconditioned_embeddings, guidance_scale,
                 extra_conditioning_info=extra_conditioning_info,
                 timesteps=timesteps,
-                run_id=run_id, **extra_step_kwargs):
-            if callback is not None and isinstance(result, PipelineIntermediateState):
-                callback(result)
-        if result is None:
-            raise AssertionError("why was that an empty generator?")
-        return result
+                callback=callback,
+                run_id=run_id, **extra_step_kwargs)
 
     def inpaint_from_embeddings(
         self,
@@ -459,7 +517,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
 
         # 6. Prepare latent variables
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
-        latents, init_image_latents = self.prepare_latents_from_image(init_image, latent_timestep, latents_dtype, device, noise_func)
+        # can't quite use upstream StableDiffusionImg2ImgPipeline.prepare_latents
+        # because we have our own noise function
+        init_image_latents = self.non_noised_latents_from_image(init_image, device=device, dtype=latents_dtype)
+        latents = self.noise_latents_for_time(init_image_latents, latent_timestep, noise_func=noise_func)
 
         if mask.dim() == 3:
             mask = mask.unsqueeze(0)
@@ -491,19 +552,18 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         finally:
             self.invokeai_diffuser.model_forward_callback = self._unet_forward
 
 
-    def prepare_latents_from_image(self, init_image, timestep, dtype, device, noise_func) -> (torch.FloatTensor, torch.FloatTensor):
-        # can't quite use upstream StableDiffusionImg2ImgPipeline.prepare_latents
-        # because we have our own noise function
+    def non_noised_latents_from_image(self, init_image, *, device, dtype):
         init_image = init_image.to(device=device, dtype=dtype)
         with torch.inference_mode():
             init_latent_dist = self.vae.encode(init_image).latent_dist
             init_latents = init_latent_dist.sample().to(dtype=dtype)  # FIXME: uses torch.randn. make reproducible!
         init_latents = 0.18215 * init_latents
+        return init_latents
 
-        noise = noise_func(init_latents)
-        noised_latents = self.scheduler.add_noise(init_latents, noise, timestep)
-        return noised_latents, init_latents
+    def noise_latents_for_time(self, latents, timestep, *, noise_func):
+        noise = noise_func(latents)
+        noised_latents = self.scheduler.add_noise(latents, noise, timestep)
+        return noised_latents
 
     def check_for_safety(self, output, dtype):
         with torch.inference_mode():
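One note on the noise_func hook used by noise_latents_for_time above: any callable that maps a latents tensor to same-shaped noise will do, and the generators below pass self.get_noise_like. A hypothetical seeded variant, shown only as a sketch:

    import torch

    def make_seeded_noise_func(seed: int):
        # hypothetical helper: builds a deterministic noise_func for noise_latents_for_time()
        def noise_func(like: torch.Tensor) -> torch.Tensor:
            generator = torch.Generator(device='cpu').manual_seed(seed)
            return torch.randn(like.shape, generator=generator, dtype=like.dtype).to(device=like.device)
        return noise_func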
ldm/invoke/generator/txt2img.py

@@ -29,10 +29,6 @@ class Txt2Img(Generator):
         pipeline.scheduler = sampler
 
         def make_image(x_T) -> PIL.Image.Image:
-            # FIXME: restore free_gpu_mem functionality
-            # if self.free_gpu_mem and self.model.model.device != self.model.device:
-            #     self.model.model.to(self.model.device)
-
             pipeline_output = pipeline.image_from_embeddings(
                 latents=x_T,
                 num_inference_steps=steps,
@@ -45,10 +41,6 @@ class Txt2Img(Generator):
                 # TODO: threshold = threshold,
             )
 
-            # FIXME: restore free_gpu_mem functionality
-            # if self.free_gpu_mem:
-            #     self.model.model.to("cpu")
-
             return pipeline.numpy_to_pil(pipeline_output.images)[0]
 
         return make_image
ldm/invoke/generator/txt2img2img.py

@@ -3,13 +3,12 @@ ldm.invoke.generator.txt2img inherits from ldm.invoke.generator
 '''
 
 import math
 from typing import Callable, Optional
 
 import torch
 from PIL import Image
 
 from ldm.invoke.generator.base import Generator
 from ldm.invoke.generator.omnibus import Omnibus
-from ldm.models.diffusion.ddim import DDIMSampler
+from ldm.invoke.generator.diffusers_pipeline import trim_to_multiple_of, StableDiffusionGeneratorPipeline
 
 
@@ -17,9 +16,9 @@ class Txt2Img2Img(Generator):
         super().__init__(model, precision)
         self.init_latent = None    # for get_noise()
 
-    @torch.no_grad()
-    def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
-                       conditioning,width,height,strength,step_callback=None,**kwargs):
+    def get_make_image(self, prompt:str, sampler, steps:int, cfg_scale:float, ddim_eta,
+                       conditioning, width:int, height:int, strength:float,
+                       step_callback:Optional[Callable]=None, **kwargs):
         """
         Returns a function returning an image derived from the prompt and the initial image
         Return value depends on the seed at the time you call it
@@ -29,126 +28,73 @@ class Txt2Img2Img(Generator):
         scale_dim = min(width, height)
         scale = 512 / scale_dim
 
-        init_width = math.ceil(scale * width / 64) * 64
-        init_height = math.ceil(scale * height / 64) * 64
+        init_width, init_height = trim_to_multiple_of(scale * width, scale * height)
 
+        # noinspection PyTypeChecker
+        pipeline: StableDiffusionGeneratorPipeline = self.model
+        pipeline.scheduler = sampler
+
+        @torch.no_grad()
         def make_image(x_T):
 
-            shape = [
-                self.latent_channels,
-                init_height // self.downsampling_factor,
-                init_width // self.downsampling_factor,
-            ]
-
-            sampler.make_schedule(
-                ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
+            pipeline_output = pipeline.latents_from_embeddings(
+                latents=x_T,
+                num_inference_steps=steps,
+                text_embeddings=c,
+                unconditioned_embeddings=uc,
+                guidance_scale=cfg_scale,
+                callback=step_callback,
+                extra_conditioning_info=extra_conditioning_info,
+                # TODO: eta = ddim_eta,
+                # TODO: threshold = threshold,
             )
 
-            #x = self.get_noise(init_width, init_height)
-            x = x_T
-
-            if self.free_gpu_mem and self.model.model.device != self.model.device:
-                self.model.model.to(self.model.device)
-
-            samples, _ = sampler.sample(
-                batch_size = 1,
-                S = steps,
-                x_T = x,
-                conditioning = c,
-                shape = shape,
-                verbose = False,
-                unconditional_guidance_scale = cfg_scale,
-                unconditional_conditioning = uc,
-                eta = ddim_eta,
-                img_callback = step_callback,
-                extra_conditioning_info = extra_conditioning_info
-            )
+            first_pass_latent_output = pipeline_output.latents
 
             print(
                 f"\n>> Interpolating from {init_width}x{init_height} to {width}x{height} using DDIM sampling"
             )
 
             # resizing
-            samples = torch.nn.functional.interpolate(
-                samples,
+            resized_latents = torch.nn.functional.interpolate(
+                first_pass_latent_output,
                 size=(height // self.downsampling_factor, width // self.downsampling_factor),
                 mode="bilinear"
             )
 
-            t_enc = int(strength * steps)
-            ddim_sampler = DDIMSampler(self.model, device=self.model.device)
-            ddim_sampler.make_schedule(
-                ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
-            )
-
-            z_enc = ddim_sampler.stochastic_encode(
-                samples,
-                torch.tensor([t_enc]).to(self.model.device),
-                noise=self.get_noise(width,height,False)
-            )
-
-            # decode it
-            samples = ddim_sampler.decode(
-                z_enc,
-                c,
-                t_enc,
-                img_callback = step_callback,
-                unconditional_guidance_scale=cfg_scale,
-                unconditional_conditioning=uc,
+            pipeline_output = pipeline.img2img_from_latents_and_embeddings(
+                resized_latents,
+                num_inference_steps=steps,
+                text_embeddings=c,
+                unconditioned_embeddings=uc,
+                guidance_scale=cfg_scale, strength=strength,
                 extra_conditioning_info=extra_conditioning_info,
-                all_timesteps_count=steps
-            )
+                noise_func=self.get_noise_like,
+                callback=step_callback)
 
-            if self.free_gpu_mem:
-                self.model.model.to("cpu")
+            return pipeline.numpy_to_pil(pipeline_output.images)[0]
 
-            return self.sample_to_image(samples)
-
+        # FIXME: do we really need something entirely different for the inpainting model?
 
-        # in the case of the inpainting model being loaded, the trick of
-        # providing an interpolated latent doesn't work, so we transiently
-        # create a 512x512 PIL image, upscale it, and run the inpainting
-        # over it in img2img mode. Because the inpaing model is so conservative
-        # it doesn't change the image (much)
-        def inpaint_make_image(x_T):
-            omnibus = Omnibus(self.model,self.precision)
-            result = omnibus.generate(
-                prompt,
-                sampler=sampler,
-                width=init_width,
-                height=init_height,
-                step_callback=step_callback,
-                steps = steps,
-                cfg_scale = cfg_scale,
-                ddim_eta = ddim_eta,
-                conditioning = conditioning,
-                **kwargs
-            )
-            assert result is not None and len(result)>0,'** txt2img failed **'
-            image = result[0][0]
-            interpolated_image = image.resize((width,height),resample=Image.Resampling.LANCZOS)
-            print(kwargs.pop('init_image',None))
-            result = omnibus.generate(
-                prompt,
-                sampler=sampler,
-                init_image=interpolated_image,
-                width=width,
-                height=height,
-                seed=result[0][1],
-                step_callback=step_callback,
-                steps = steps,
-                cfg_scale = cfg_scale,
-                ddim_eta = ddim_eta,
-                conditioning = conditioning,
-                **kwargs
-            )
-            return result[0][0]
-
-        if sampler.uses_inpainting_model():
-            return inpaint_make_image
-        else:
-            return make_image
+        return make_image
 
+    def get_noise_like(self, like: torch.Tensor):
+        device = like.device
+        if device.type == 'mps':
+            x = torch.randn_like(like, device='cpu').to(device)
+        else:
+            x = torch.randn_like(like, device=device)
+        if self.perlin > 0.0:
+            shape = like.shape
+            x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
+        return x
+
     # returns a tensor filled with random numbers from a normal distribution
     def get_noise(self,width,height,scale = True):
         # print(f"Get noise: {width}x{height}")
@@ -175,4 +121,3 @@ class Txt2Img2Img(Generator):
                               scaled_height // self.downsampling_factor,
                               scaled_width // self.downsampling_factor],
                              device=device)
-