diffusers integration: support img2img

Kevin Turner 2022-11-23 17:38:31 -08:00
parent f3f6213b97
commit efbb807905
2 changed files with 106 additions and 57 deletions

ldm/invoke/generator/diffusers_pipeline.py

@@ -3,10 +3,12 @@ import warnings
from dataclasses import dataclass
from typing import List, Optional, Union, Callable
import PIL.Image
import torch
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import preprocess
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
@@ -210,6 +212,7 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
*,
run_id: str = None,
extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None,
timesteps = None,
**extra_step_kwargs):
if run_id is None:
run_id = secrets.token_urlsafe(self.ID_LENGTH)
@@ -220,16 +223,19 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
else:
self.invokeai_diffuser.remove_cross_attention_control()
if timesteps is None:
timesteps = self.scheduler.timesteps
# scale the initial noise by the standard deviation required by the scheduler
latents *= self.scheduler.init_noise_sigma
yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps,
latents=latents)
batch_size = latents.shape[0]
batched_t = torch.full((batch_size,), self.scheduler.timesteps[0],
dtype=self.scheduler.timesteps.dtype, device=self.unet.device)
batched_t = torch.full((batch_size,), timesteps[0],
dtype=timesteps.dtype, device=self.unet.device)
# NOTE: Depends on scheduler being already initialized!
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
for i, t in enumerate(self.progress_bar(timesteps)):
batched_t.fill_(t)
step_output = self.step(batched_t, latents, guidance_scale,
text_embeddings, unconditioned_embeddings,
@@ -272,6 +278,68 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
# predict the noise residual
return self.unet(latents, t, encoder_hidden_states=text_embeddings).sample
def img2img_from_embeddings(self,
init_image: Union[torch.FloatTensor, PIL.Image.Image],
strength: float,
num_inference_steps: int,
text_embeddings: torch.Tensor, unconditioned_embeddings: torch.Tensor,
guidance_scale: float,
*, callback: Callable[[PipelineIntermediateState], None] = None,
extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None,
run_id=None,
noise_func=None,
**extra_step_kwargs) -> StableDiffusionPipelineOutput:
device = self.unet.device
latents_dtype = text_embeddings.dtype
batch_size = 1
num_images_per_prompt = 1
if isinstance(init_image, PIL.Image.Image):
init_image = preprocess(init_image.convert('RGB'))
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self._diffusers08_get_timesteps(num_inference_steps, strength)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
# 6. Prepare latent variables
latents = self.prepare_latents_from_image(init_image, latent_timestep, latents_dtype, device, noise_func)
result = None
for result in self.generate_from_embeddings(
latents, text_embeddings, unconditioned_embeddings, guidance_scale,
extra_conditioning_info=extra_conditioning_info,
timesteps=timesteps,
run_id=run_id, **extra_step_kwargs):
if callback is not None and isinstance(result, PipelineIntermediateState):
callback(result)
if result is None:
raise AssertionError("why was that an empty generator?")
return result
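For orientation, a minimal usage sketch of this new entry point follows; the `pipeline`, `c`, `uc`, and file name are illustrative assumptions, not part of this commit.
# Hypothetical caller sketch (names are illustrative):
# `pipeline` is an already-constructed StableDiffusionGeneratorPipeline;
# `c` and `uc` are the conditioned and unconditioned text embeddings.
init_image = PIL.Image.open('input.png')
output = pipeline.img2img_from_embeddings(
    init_image, strength=0.75, num_inference_steps=50,
    text_embeddings=c, unconditioned_embeddings=uc, guidance_scale=7.5,
    noise_func=lambda like: torch.randn_like(like))  # simplified stand-in
result_image = pipeline.numpy_to_pil(output.images)[0]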
def prepare_latents_from_image(self, init_image, timestep, dtype, device, noise_func) -> torch.FloatTensor:
# can't quite use upstream StableDiffusionImg2ImgPipeline.prepare_latents
# because we have our own noise function
init_image = init_image.to(device=device, dtype=dtype)
with torch.inference_mode():
init_latent_dist = self.vae.encode(init_image).latent_dist
init_latents = init_latent_dist.sample() # FIXME: uses torch.randn. make reproducible!
init_latents = 0.18215 * init_latents  # scale by the Stable Diffusion VAE's latent scaling factor
noise = noise_func(init_latents)
return self.scheduler.add_noise(init_latents, noise, timestep)
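A note on the FIXME above: one way to make the encode step reproducible is to pass a seeded torch.Generator into the latent distribution's sample() call, as the upstream diffusers img2img pipeline does. A sketch, assuming the installed diffusers version accepts the generator argument and that a `seed` value is in hand:
generator = torch.Generator(device='cpu').manual_seed(seed)  # `seed` is assumed
init_latents = init_latent_dist.sample(generator=generator)  # deterministic given seed
init_latents = 0.18215 * init_latents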
def _diffusers08_get_timesteps(self, num_inference_steps, strength):
# get the original timestep using init_timestep
offset = self.scheduler.config.get("steps_offset", 0)
init_timestep = int(num_inference_steps * strength) + offset
init_timestep = min(init_timestep, num_inference_steps)
t_start = max(num_inference_steps - init_timestep + offset, 0)
timesteps = self.scheduler.timesteps[t_start:]
return timesteps
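A worked example of the timestep arithmetic above, with illustrative values:
# num_inference_steps=50, strength=0.75, steps_offset=1:
#   init_timestep = int(50 * 0.75) + 1 = 38  (capped at 50)
#   t_start       = max(50 - 38 + 1, 0) = 13
# so the method returns self.scheduler.timesteps[13:], i.e. the last 37 of the
# 50 scheduled steps: denoising starts partway in, from the noised init latents.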
@torch.inference_mode()
def check_for_safety(self, output):
if not getattr(self, 'feature_extractor') or not getattr(self, 'safety_checker'):

ldm/invoke/generator/img2img.py

@@ -2,14 +2,10 @@
ldm.invoke.generator.img2img descends from ldm.invoke.generator
'''
import PIL
import numpy as np
import torch
from PIL import Image
from torch import Tensor
from ldm.invoke.devices import choose_autocast
from ldm.invoke.generator.base import Generator
from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
class Img2Img(Generator):
@@ -25,66 +21,51 @@ class Img2Img(Generator):
"""
self.perlin = perlin
sampler.make_schedule(
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
)
if isinstance(init_image, PIL.Image.Image):
init_image = self._image_to_tensor(init_image.convert('RGB'))
scope = choose_autocast(self.precision)
with scope(self.model.device.type):
self.init_latent = self.model.get_first_stage_encoding(
self.model.encode_first_stage(init_image)
) # move to latent space
t_enc = int(strength * steps)
uc, c, extra_conditioning_info = conditioning
# noinspection PyTypeChecker
pipeline: StableDiffusionGeneratorPipeline = self.model
pipeline.scheduler = sampler
def make_image(x_T):
# encode (scaled latent)
z_enc = sampler.stochastic_encode(
self.init_latent,
torch.tensor([t_enc]).to(self.model.device),
noise=x_T
)
# decode it
samples = sampler.decode(
z_enc,
c,
t_enc,
img_callback = step_callback,
unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,
init_latent = self.init_latent, # changes how noising is performed in ksampler
extra_conditioning_info = extra_conditioning_info,
all_timesteps_count = steps
# FIXME: use x_T for initial seeded noise
pipeline_output = pipeline.img2img_from_embeddings(
init_image, strength, steps, c, uc, cfg_scale,
extra_conditioning_info=extra_conditioning_info,
noise_func=self.get_noise_like,
callback=step_callback
)
return self.sample_to_image(samples)
return pipeline.numpy_to_pil(pipeline_output.images)[0]
return make_image
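As a reading aid, the returned closure is consumed roughly like this; the enclosing method name and argument lists are assumptions, sketched rather than taken from this diff:
# Hypothetical caller sketch:
make_image = generator.get_make_image(...)  # illustrative; real args elided
x_T = generator.get_noise(width, height)    # seeded initial latent noise
image = make_image(x_T)                     # per the FIXME above, x_T is not yet
                                            # threaded into the diffusers path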
def get_noise(self,width,height):
device = self.model.device
init_latent = self.init_latent
assert init_latent is not None,'call to get_noise() when init_latent not set'
def get_noise_like(self, like: torch.Tensor):
device = like.device
if device.type == 'mps':
x = torch.randn_like(init_latent, device='cpu').to(device)
x = torch.randn_like(like, device='cpu').to(device)
else:
x = torch.randn_like(init_latent, device=device)
x = torch.randn_like(like, device=device)
if self.perlin > 0.0:
shape = init_latent.shape
shape = like.shape
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])  # NCHW: shape[3] is width, shape[2] is height
return x
def _image_to_tensor(self, image:Image, normalize:bool=True)->Tensor:
image = np.array(image).astype(np.float32) / 255.0
if len(image.shape) == 2: # 'L' image, as in a mask
image = image[None,None]
else: # 'RGB' image
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
if normalize:
image = 2.0 * image - 1.0
return image.to(self.model.device)
def get_noise(self,width,height):
# copy of the Txt2Img.get_noise
device = self.model.device
if self.use_mps_noise or device.type == 'mps':
x = torch.randn([1,
self.latent_channels,
height // self.downsampling_factor,
width // self.downsampling_factor],
device='cpu').to(device)
else:
x = torch.randn([1,
self.latent_channels,
height // self.downsampling_factor,
width // self.downsampling_factor],
device=device)
if self.perlin > 0.0:
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor)
return x
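For concreteness, the shape arithmetic in get_noise with typical Stable Diffusion values:
# width=512, height=512, downsampling_factor=8, latent_channels=4:
#   latent shape is [1, 4, 512 // 8, 512 // 8] == [1, 4, 64, 64]
# The MPS branch samples on CPU and then moves the tensor, because torch.randn
# on MPS devices was not reliably seed-reproducible at the time of this commit.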