diffusers integration: support img2img

Kevin Turner 2022-11-23 17:38:31 -08:00
parent f3f6213b97
commit efbb807905
2 changed files with 106 additions and 57 deletions

ldm/invoke/generator/diffusers_pipeline.py

@@ -3,10 +3,12 @@ import warnings
 from dataclasses import dataclass
 from typing import List, Optional, Union, Callable
 
+import PIL.Image
 import torch
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import preprocess
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
@@ -210,6 +212,7 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
                                  *,
                                  run_id: str = None,
                                  extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None,
+                                 timesteps = None,
                                  **extra_step_kwargs):
         if run_id is None:
             run_id = secrets.token_urlsafe(self.ID_LENGTH)
@@ -220,16 +223,19 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
         else:
             self.invokeai_diffuser.remove_cross_attention_control()
 
+        if timesteps is None:
+            timesteps = self.scheduler.timesteps
+
         # scale the initial noise by the standard deviation required by the scheduler
         latents *= self.scheduler.init_noise_sigma
         yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps,
                                         latents=latents)
 
         batch_size = latents.shape[0]
-        batched_t = torch.full((batch_size,), self.scheduler.timesteps[0],
-                               dtype=self.scheduler.timesteps.dtype, device=self.unet.device)
+        batched_t = torch.full((batch_size,), timesteps[0],
+                               dtype=timesteps.dtype, device=self.unet.device)
         # NOTE: Depends on scheduler being already initialized!
-        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
+        for i, t in enumerate(self.progress_bar(timesteps)):
             batched_t.fill_(t)
             step_output = self.step(batched_t, latents, guidance_scale,
                                     text_embeddings, unconditioned_embeddings,
@@ -272,6 +278,68 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
         # predict the noise residual
         return self.unet(latents, t, encoder_hidden_states=text_embeddings).sample
 
+    def img2img_from_embeddings(self,
+                                init_image: Union[torch.FloatTensor, PIL.Image.Image],
+                                strength: float,
+                                num_inference_steps: int,
+                                text_embeddings: torch.Tensor, unconditioned_embeddings: torch.Tensor,
+                                guidance_scale: float,
+                                *, callback: Callable[[PipelineIntermediateState], None] = None,
+                                extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None,
+                                run_id=None,
+                                noise_func=None,
+                                **extra_step_kwargs) -> StableDiffusionPipelineOutput:
+        device = self.unet.device
+        latents_dtype = text_embeddings.dtype
+        batch_size = 1
+        num_images_per_prompt = 1
+
+        if isinstance(init_image, PIL.Image.Image):
+            init_image = preprocess(init_image.convert('RGB'))
+
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
+        timesteps = self._diffusers08_get_timesteps(num_inference_steps, strength)
+        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+        # 6. Prepare latent variables
+        latents = self.prepare_latents_from_image(init_image, latent_timestep, latents_dtype, device, noise_func)
+
+        result = None
+        for result in self.generate_from_embeddings(
+                latents, text_embeddings, unconditioned_embeddings, guidance_scale,
+                extra_conditioning_info=extra_conditioning_info,
+                timesteps=timesteps,
+                run_id=run_id, **extra_step_kwargs):
+            if callback is not None and isinstance(result, PipelineIntermediateState):
+                callback(result)
+        if result is None:
+            raise AssertionError("why was that an empty generator?")
+        return result
+
+    def prepare_latents_from_image(self, init_image, timestep, dtype, device, noise_func) -> torch.FloatTensor:
+        # can't quite use upstream StableDiffusionImg2ImgPipeline.prepare_latents
+        # because we have our own noise function
+        init_image = init_image.to(device=device, dtype=dtype)
+        with torch.inference_mode():
+            init_latent_dist = self.vae.encode(init_image).latent_dist
+            init_latents = init_latent_dist.sample()  # FIXME: uses torch.randn. make reproducible!
+        init_latents = 0.18215 * init_latents
+
+        noise = noise_func(init_latents)
+        return self.scheduler.add_noise(init_latents, noise, timestep)
+
+    def _diffusers08_get_timesteps(self, num_inference_steps, strength):
+        # get the original timestep using init_timestep
+        offset = self.scheduler.config.get("steps_offset", 0)
+        init_timestep = int(num_inference_steps * strength) + offset
+        init_timestep = min(init_timestep, num_inference_steps)
+
+        t_start = max(num_inference_steps - init_timestep + offset, 0)
+        timesteps = self.scheduler.timesteps[t_start:]
+
+        return timesteps
+
     @torch.inference_mode()
     def check_for_safety(self, output):
         if not getattr(self, 'feature_extractor') or not getattr(self, 'safety_checker'):
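
For reference, a minimal usage sketch of the new entry point, mirroring the call made in ldm.invoke.generator.img2img below; the pipeline instance, the embedding tensors, and the literal values here are assumptions for illustration, not part of this commit:

    # Illustrative only: assumes `pipeline` is a loaded StableDiffusionGeneratorPipeline
    # and `c` / `uc` are precomputed conditioned / unconditioned text embeddings.
    import PIL.Image
    import torch

    init_image = PIL.Image.open('input.png')   # a PIL image or a latent FloatTensor is accepted
    output = pipeline.img2img_from_embeddings(
        init_image,
        strength=0.75,                 # fraction of the schedule re-run on the noised init latents
        num_inference_steps=50,
        text_embeddings=c,
        unconditioned_embeddings=uc,
        guidance_scale=7.5,
        noise_func=lambda latents: torch.randn_like(latents),  # caller supplies the noise, keeping seeding outside the pipeline
    )
    image = pipeline.numpy_to_pil(output.images)[0]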

ldm/invoke/generator/img2img.py

@@ -2,14 +2,10 @@
 ldm.invoke.generator.img2img descends from ldm.invoke.generator
 '''
 
-import PIL
-import numpy as np
 import torch
-from PIL import Image
-from torch import Tensor
 
-from ldm.invoke.devices import choose_autocast
 from ldm.invoke.generator.base import Generator
+from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
 
 
 class Img2Img(Generator):
@@ -25,66 +21,51 @@ class Img2Img(Generator):
         """
         self.perlin = perlin
 
-        sampler.make_schedule(
-            ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
-        )
-
-        if isinstance(init_image, PIL.Image.Image):
-            init_image = self._image_to_tensor(init_image.convert('RGB'))
-
-        scope = choose_autocast(self.precision)
-        with scope(self.model.device.type):
-            self.init_latent = self.model.get_first_stage_encoding(
-                self.model.encode_first_stage(init_image)
-            ) # move to latent space
-
-        t_enc = int(strength * steps)
         uc, c, extra_conditioning_info = conditioning
 
+        # noinspection PyTypeChecker
+        pipeline: StableDiffusionGeneratorPipeline = self.model
+        pipeline.scheduler = sampler
+
         def make_image(x_T):
-            # encode (scaled latent)
-            z_enc = sampler.stochastic_encode(
-                self.init_latent,
-                torch.tensor([t_enc]).to(self.model.device),
-                noise=x_T
-            )
-            # decode it
-            samples = sampler.decode(
-                z_enc,
-                c,
-                t_enc,
-                img_callback = step_callback,
-                unconditional_guidance_scale=cfg_scale,
-                unconditional_conditioning=uc,
-                init_latent = self.init_latent, # changes how noising is performed in ksampler
+            # FIXME: use x_T for initial seeded noise
+            pipeline_output = pipeline.img2img_from_embeddings(
+                init_image, strength, steps, c, uc, cfg_scale,
                 extra_conditioning_info=extra_conditioning_info,
-                all_timesteps_count = steps
+                noise_func=self.get_noise_like,
+                callback=step_callback
             )
-            return self.sample_to_image(samples)
+            return pipeline.numpy_to_pil(pipeline_output.images)[0]
 
         return make_image
 
-    def get_noise(self,width,height):
-        device = self.model.device
-        init_latent = self.init_latent
-        assert init_latent is not None,'call to get_noise() when init_latent not set'
+    def get_noise_like(self, like: torch.Tensor):
+        device = like.device
         if device.type == 'mps':
-            x = torch.randn_like(init_latent, device='cpu').to(device)
+            x = torch.randn_like(like, device='cpu').to(device)
         else:
-            x = torch.randn_like(init_latent, device=device)
+            x = torch.randn_like(like, device=device)
         if self.perlin > 0.0:
-            shape = init_latent.shape
+            shape = like.shape
             x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
         return x
 
-    def _image_to_tensor(self, image:Image, normalize:bool=True)->Tensor:
-        image = np.array(image).astype(np.float32) / 255.0
-        if len(image.shape) == 2:  # 'L' image, as in a mask
-            image = image[None,None]
-        else:                      # 'RGB' image
-            image = image[None].transpose(0, 3, 1, 2)
-        image = torch.from_numpy(image)
-        if normalize:
-            image = 2.0 * image - 1.0
-        return image.to(self.model.device)
+    def get_noise(self,width,height):
+        # copy of the Txt2Img.get_noise
+        device = self.model.device
+        if self.use_mps_noise or device.type == 'mps':
+            x = torch.randn([1,
+                             self.latent_channels,
+                             height // self.downsampling_factor,
+                             width // self.downsampling_factor],
+                            device='cpu').to(device)
+        else:
+            x = torch.randn([1,
+                             self.latent_channels,
+                             height // self.downsampling_factor,
+                             width // self.downsampling_factor],
+                            device=device)
+        if self.perlin > 0.0:
+            x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor)
+        return x
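
As a quick sanity check of the strength-to-timesteps mapping in _diffusers08_get_timesteps above, here is the same arithmetic run standalone; the step count, strength, and offset are example values (the real offset comes from the scheduler's steps_offset config):

    # Standalone rerun of the _diffusers08_get_timesteps arithmetic, for illustration.
    num_inference_steps = 50
    strength = 0.75
    offset = 1   # assumed; e.g. a scheduler configured with steps_offset=1

    init_timestep = min(int(num_inference_steps * strength) + offset, num_inference_steps)   # 38
    t_start = max(num_inference_steps - init_timestep + offset, 0)                           # 13
    # scheduler.timesteps[13:] is returned, so 37 of the 50 steps actually run, and
    # prepare_latents_from_image noises the init latents to that first remaining timestep.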