mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
diffusers integration: support img2img
This commit is contained in:
parent
f3f6213b97
commit
efbb807905
@ -3,10 +3,12 @@ import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Union, Callable
|
||||
|
||||
import PIL.Image
|
||||
import torch
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import preprocess
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
@ -210,6 +212,7 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
|
||||
*,
|
||||
run_id: str = None,
|
||||
extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None,
|
||||
timesteps = None,
|
||||
**extra_step_kwargs):
|
||||
if run_id is None:
|
||||
run_id = secrets.token_urlsafe(self.ID_LENGTH)
|
||||
@ -220,16 +223,19 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
|
||||
else:
|
||||
self.invokeai_diffuser.remove_cross_attention_control()
|
||||
|
||||
if timesteps is None:
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents *= self.scheduler.init_noise_sigma
|
||||
yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps,
|
||||
latents=latents)
|
||||
|
||||
batch_size = latents.shape[0]
|
||||
batched_t = torch.full((batch_size,), self.scheduler.timesteps[0],
|
||||
dtype=self.scheduler.timesteps.dtype, device=self.unet.device)
|
||||
batched_t = torch.full((batch_size,), timesteps[0],
|
||||
dtype=timesteps.dtype, device=self.unet.device)
|
||||
# NOTE: Depends on scheduler being already initialized!
|
||||
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
batched_t.fill_(t)
|
||||
step_output = self.step(batched_t, latents, guidance_scale,
|
||||
text_embeddings, unconditioned_embeddings,
|
||||
@ -272,6 +278,68 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
|
||||
# predict the noise residual
|
||||
return self.unet(latents, t, encoder_hidden_states=text_embeddings).sample
|
||||
|
||||
def img2img_from_embeddings(self,
|
||||
init_image: Union[torch.FloatTensor, PIL.Image.Image],
|
||||
strength: float,
|
||||
num_inference_steps: int,
|
||||
text_embeddings: torch.Tensor, unconditioned_embeddings: torch.Tensor,
|
||||
guidance_scale: float,
|
||||
*, callback: Callable[[PipelineIntermediateState], None] = None,
|
||||
extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None,
|
||||
run_id=None,
|
||||
noise_func=None,
|
||||
**extra_step_kwargs) -> StableDiffusionPipelineOutput:
|
||||
device = self.unet.device
|
||||
latents_dtype = text_embeddings.dtype
|
||||
batch_size = 1
|
||||
num_images_per_prompt = 1
|
||||
|
||||
if isinstance(init_image, PIL.Image.Image):
|
||||
init_image = preprocess(init_image.convert('RGB'))
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self._diffusers08_get_timesteps(num_inference_steps, strength)
|
||||
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
||||
|
||||
# 6. Prepare latent variables
|
||||
latents = self.prepare_latents_from_image(init_image, latent_timestep, latents_dtype, device, noise_func)
|
||||
|
||||
result = None
|
||||
for result in self.generate_from_embeddings(
|
||||
latents, text_embeddings, unconditioned_embeddings, guidance_scale,
|
||||
extra_conditioning_info=extra_conditioning_info,
|
||||
timesteps=timesteps,
|
||||
run_id=run_id, **extra_step_kwargs):
|
||||
if callback is not None and isinstance(result, PipelineIntermediateState):
|
||||
callback(result)
|
||||
if result is None:
|
||||
raise AssertionError("why was that an empty generator?")
|
||||
return result
|
||||
|
||||
def prepare_latents_from_image(self, init_image, timestep, dtype, device, noise_func) -> torch.FloatTensor:
|
||||
# can't quite use upstream StableDiffusionImg2ImgPipeline.prepare_latents
|
||||
# because we have our own noise function
|
||||
init_image = init_image.to(device=device, dtype=dtype)
|
||||
with torch.inference_mode():
|
||||
init_latent_dist = self.vae.encode(init_image).latent_dist
|
||||
init_latents = init_latent_dist.sample() # FIXME: uses torch.randn. make reproducible!
|
||||
init_latents = 0.18215 * init_latents
|
||||
|
||||
noise = noise_func(init_latents)
|
||||
|
||||
return self.scheduler.add_noise(init_latents, noise, timestep)
|
||||
|
||||
def _diffusers08_get_timesteps(self, num_inference_steps, strength):
|
||||
# get the original timestep using init_timestep
|
||||
offset = self.scheduler.config.get("steps_offset", 0)
|
||||
init_timestep = int(num_inference_steps * strength) + offset
|
||||
init_timestep = min(init_timestep, num_inference_steps)
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep + offset, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start:]
|
||||
|
||||
return timesteps
|
||||
|
||||
@torch.inference_mode()
|
||||
def check_for_safety(self, output):
|
||||
if not getattr(self, 'feature_extractor') or not getattr(self, 'safety_checker'):
|
||||
|
@ -2,14 +2,10 @@
|
||||
ldm.invoke.generator.img2img descends from ldm.invoke.generator
|
||||
'''
|
||||
|
||||
import PIL
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch import Tensor
|
||||
|
||||
from ldm.invoke.devices import choose_autocast
|
||||
from ldm.invoke.generator.base import Generator
|
||||
from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
||||
|
||||
|
||||
class Img2Img(Generator):
|
||||
@ -25,66 +21,51 @@ class Img2Img(Generator):
|
||||
"""
|
||||
self.perlin = perlin
|
||||
|
||||
sampler.make_schedule(
|
||||
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
|
||||
)
|
||||
|
||||
if isinstance(init_image, PIL.Image.Image):
|
||||
init_image = self._image_to_tensor(init_image.convert('RGB'))
|
||||
|
||||
scope = choose_autocast(self.precision)
|
||||
with scope(self.model.device.type):
|
||||
self.init_latent = self.model.get_first_stage_encoding(
|
||||
self.model.encode_first_stage(init_image)
|
||||
) # move to latent space
|
||||
|
||||
t_enc = int(strength * steps)
|
||||
uc, c, extra_conditioning_info = conditioning
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
pipeline: StableDiffusionGeneratorPipeline = self.model
|
||||
pipeline.scheduler = sampler
|
||||
|
||||
def make_image(x_T):
|
||||
# encode (scaled latent)
|
||||
z_enc = sampler.stochastic_encode(
|
||||
self.init_latent,
|
||||
torch.tensor([t_enc]).to(self.model.device),
|
||||
noise=x_T
|
||||
)
|
||||
# decode it
|
||||
samples = sampler.decode(
|
||||
z_enc,
|
||||
c,
|
||||
t_enc,
|
||||
img_callback = step_callback,
|
||||
unconditional_guidance_scale=cfg_scale,
|
||||
unconditional_conditioning=uc,
|
||||
init_latent = self.init_latent, # changes how noising is performed in ksampler
|
||||
extra_conditioning_info = extra_conditioning_info,
|
||||
all_timesteps_count = steps
|
||||
# FIXME: use x_T for initial seeded noise
|
||||
pipeline_output = pipeline.img2img_from_embeddings(
|
||||
init_image, strength, steps, c, uc, cfg_scale,
|
||||
extra_conditioning_info=extra_conditioning_info,
|
||||
noise_func=self.get_noise_like,
|
||||
callback=step_callback
|
||||
)
|
||||
|
||||
return self.sample_to_image(samples)
|
||||
return pipeline.numpy_to_pil(pipeline_output.images)[0]
|
||||
|
||||
return make_image
|
||||
|
||||
def get_noise(self,width,height):
|
||||
device = self.model.device
|
||||
init_latent = self.init_latent
|
||||
assert init_latent is not None,'call to get_noise() when init_latent not set'
|
||||
def get_noise_like(self, like: torch.Tensor):
|
||||
device = like.device
|
||||
if device.type == 'mps':
|
||||
x = torch.randn_like(init_latent, device='cpu').to(device)
|
||||
x = torch.randn_like(like, device='cpu').to(device)
|
||||
else:
|
||||
x = torch.randn_like(init_latent, device=device)
|
||||
x = torch.randn_like(like, device=device)
|
||||
if self.perlin > 0.0:
|
||||
shape = init_latent.shape
|
||||
shape = like.shape
|
||||
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
|
||||
return x
|
||||
|
||||
def _image_to_tensor(self, image:Image, normalize:bool=True)->Tensor:
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
if len(image.shape) == 2: # 'L' image, as in a mask
|
||||
image = image[None,None]
|
||||
else: # 'RGB' image
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image)
|
||||
if normalize:
|
||||
image = 2.0 * image - 1.0
|
||||
return image.to(self.model.device)
|
||||
def get_noise(self,width,height):
|
||||
# copy of the Txt2Img.get_noise
|
||||
device = self.model.device
|
||||
if self.use_mps_noise or device.type == 'mps':
|
||||
x = torch.randn([1,
|
||||
self.latent_channels,
|
||||
height // self.downsampling_factor,
|
||||
width // self.downsampling_factor],
|
||||
device='cpu').to(device)
|
||||
else:
|
||||
x = torch.randn([1,
|
||||
self.latent_channels,
|
||||
height // self.downsampling_factor,
|
||||
width // self.downsampling_factor],
|
||||
device=device)
|
||||
if self.perlin > 0.0:
|
||||
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor)
|
||||
return x
|
||||
|
Loading…
Reference in New Issue
Block a user