mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

commit efbb807905 (parent f3f6213b97)

    diffusers integration: support img2img
--- a/ldm/invoke/generator/diffusers_pipeline.py
+++ b/ldm/invoke/generator/diffusers_pipeline.py
@@ -3,10 +3,12 @@ import warnings
 from dataclasses import dataclass
 from typing import List, Optional, Union, Callable
 
+import PIL.Image
 import torch
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import preprocess
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
@@ -210,6 +212,7 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
                                  *,
                                  run_id: str = None,
                                  extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None,
+                                 timesteps = None,
                                  **extra_step_kwargs):
         if run_id is None:
             run_id = secrets.token_urlsafe(self.ID_LENGTH)
@@ -220,16 +223,19 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
         else:
             self.invokeai_diffuser.remove_cross_attention_control()
 
+        if timesteps is None:
+            timesteps = self.scheduler.timesteps
+
         # scale the initial noise by the standard deviation required by the scheduler
         latents *= self.scheduler.init_noise_sigma
         yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps,
                                         latents=latents)
 
         batch_size = latents.shape[0]
-        batched_t = torch.full((batch_size,), self.scheduler.timesteps[0],
-                               dtype=self.scheduler.timesteps.dtype, device=self.unet.device)
+        batched_t = torch.full((batch_size,), timesteps[0],
+                               dtype=timesteps.dtype, device=self.unet.device)
         # NOTE: Depends on scheduler being already initialized!
-        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
+        for i, t in enumerate(self.progress_bar(timesteps)):
             batched_t.fill_(t)
             step_output = self.step(batched_t, latents, guidance_scale,
                                     text_embeddings, unconditioned_embeddings,
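The hunk above lets callers of generate_from_embeddings supply their own timesteps schedule instead of always using self.scheduler.timesteps; txt2img callers pass nothing and still get the full schedule, while the img2img path added further down passes only the tail of it. As a rough illustration (not part of this commit), this is what a full versus truncated schedule looks like with a plain diffusers DDIMScheduler; the slice index 13 is an arbitrary example:

from diffusers import DDIMScheduler

scheduler = DDIMScheduler()        # default SD-style config, 1000 training timesteps
scheduler.set_timesteps(50)        # choose 50 inference steps

full_schedule = scheduler.timesteps       # 50 timesteps, descending toward 0
tail_schedule = scheduler.timesteps[13:]  # skip the first 13 steps, keep the last 37

print(len(full_schedule), len(tail_schedule))  # 50 37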
@@ -272,6 +278,68 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
         # predict the noise residual
         return self.unet(latents, t, encoder_hidden_states=text_embeddings).sample
 
+    def img2img_from_embeddings(self,
+                                init_image: Union[torch.FloatTensor, PIL.Image.Image],
+                                strength: float,
+                                num_inference_steps: int,
+                                text_embeddings: torch.Tensor, unconditioned_embeddings: torch.Tensor,
+                                guidance_scale: float,
+                                *, callback: Callable[[PipelineIntermediateState], None] = None,
+                                extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None,
+                                run_id=None,
+                                noise_func=None,
+                                **extra_step_kwargs) -> StableDiffusionPipelineOutput:
+        device = self.unet.device
+        latents_dtype = text_embeddings.dtype
+        batch_size = 1
+        num_images_per_prompt = 1
+
+        if isinstance(init_image, PIL.Image.Image):
+            init_image = preprocess(init_image.convert('RGB'))
+
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
+        timesteps = self._diffusers08_get_timesteps(num_inference_steps, strength)
+        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+        # 6. Prepare latent variables
+        latents = self.prepare_latents_from_image(init_image, latent_timestep, latents_dtype, device, noise_func)
+
+        result = None
+        for result in self.generate_from_embeddings(
+                latents, text_embeddings, unconditioned_embeddings, guidance_scale,
+                extra_conditioning_info=extra_conditioning_info,
+                timesteps=timesteps,
+                run_id=run_id, **extra_step_kwargs):
+            if callback is not None and isinstance(result, PipelineIntermediateState):
+                callback(result)
+        if result is None:
+            raise AssertionError("why was that an empty generator?")
+        return result
+
+    def prepare_latents_from_image(self, init_image, timestep, dtype, device, noise_func) -> torch.FloatTensor:
+        # can't quite use upstream StableDiffusionImg2ImgPipeline.prepare_latents
+        # because we have our own noise function
+        init_image = init_image.to(device=device, dtype=dtype)
+        with torch.inference_mode():
+            init_latent_dist = self.vae.encode(init_image).latent_dist
+            init_latents = init_latent_dist.sample()  # FIXME: uses torch.randn. make reproducible!
+        init_latents = 0.18215 * init_latents
+
+        noise = noise_func(init_latents)
+
+        return self.scheduler.add_noise(init_latents, noise, timestep)
+
+    def _diffusers08_get_timesteps(self, num_inference_steps, strength):
+        # get the original timestep using init_timestep
+        offset = self.scheduler.config.get("steps_offset", 0)
+        init_timestep = int(num_inference_steps * strength) + offset
+        init_timestep = min(init_timestep, num_inference_steps)
+
+        t_start = max(num_inference_steps - init_timestep + offset, 0)
+        timesteps = self.scheduler.timesteps[t_start:]
+
+        return timesteps
+
     @torch.inference_mode()
     def check_for_safety(self, output):
         if not getattr(self, 'feature_extractor') or not getattr(self, 'safety_checker'):
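_diffusers08_get_timesteps above ports the timestep-window arithmetic from the diffusers 0.8 StableDiffusionImg2ImgPipeline: strength controls how far into the schedule denoising starts, so strength near 1.0 re-runs almost every step and small values preserve most of the original image. A self-contained sketch of just that arithmetic, with a plain descending range standing in for scheduler.timesteps (illustrative only, not code from this commit):

def timestep_window(num_inference_steps: int, strength: float, steps_offset: int = 0):
    # Stand-in for scheduler.timesteps: a descending index range, e.g. [49, 48, ..., 0].
    timesteps = list(range(num_inference_steps - 1, -1, -1))
    init_timestep = min(int(num_inference_steps * strength) + steps_offset, num_inference_steps)
    t_start = max(num_inference_steps - init_timestep + steps_offset, 0)
    return timesteps[t_start:]

# strength=0.75 with 50 steps runs only the last 37 steps of the schedule;
# strength=1.0 runs all 50.
print(len(timestep_window(50, 0.75)), len(timestep_window(50, 1.0)))  # 37 50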
--- a/ldm/invoke/generator/img2img.py
+++ b/ldm/invoke/generator/img2img.py
@@ -2,14 +2,10 @@
 ldm.invoke.generator.img2img descends from ldm.invoke.generator
 '''
 
-import PIL
-import numpy as np
 import torch
-from PIL import Image
-from torch import Tensor
 
-from ldm.invoke.devices import choose_autocast
 from ldm.invoke.generator.base import Generator
+from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
 
 
 class Img2Img(Generator):
@@ -25,66 +21,51 @@ class Img2Img(Generator):
         """
         self.perlin = perlin
 
-        sampler.make_schedule(
-            ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
-        )
-
-        if isinstance(init_image, PIL.Image.Image):
-            init_image = self._image_to_tensor(init_image.convert('RGB'))
-
-        scope = choose_autocast(self.precision)
-        with scope(self.model.device.type):
-            self.init_latent = self.model.get_first_stage_encoding(
-                self.model.encode_first_stage(init_image)
-            ) # move to latent space
-
-        t_enc = int(strength * steps)
         uc, c, extra_conditioning_info = conditioning
 
+        # noinspection PyTypeChecker
+        pipeline: StableDiffusionGeneratorPipeline = self.model
+        pipeline.scheduler = sampler
+
         def make_image(x_T):
-            # encode (scaled latent)
-            z_enc = sampler.stochastic_encode(
-                self.init_latent,
-                torch.tensor([t_enc]).to(self.model.device),
-                noise=x_T
-            )
-            # decode it
-            samples = sampler.decode(
-                z_enc,
-                c,
-                t_enc,
-                img_callback = step_callback,
-                unconditional_guidance_scale=cfg_scale,
-                unconditional_conditioning=uc,
-                init_latent = self.init_latent, # changes how noising is performed in ksampler
+            # FIXME: use x_T for initial seeded noise
+            pipeline_output = pipeline.img2img_from_embeddings(
+                init_image, strength, steps, c, uc, cfg_scale,
                 extra_conditioning_info=extra_conditioning_info,
-                all_timesteps_count = steps
+                noise_func=self.get_noise_like,
+                callback=step_callback
             )
 
-            return self.sample_to_image(samples)
+            return pipeline.numpy_to_pil(pipeline_output.images)[0]
 
         return make_image
 
-    def get_noise(self,width,height):
-        device = self.model.device
-        init_latent = self.init_latent
-        assert init_latent is not None,'call to get_noise() when init_latent not set'
+    def get_noise_like(self, like: torch.Tensor):
+        device = like.device
         if device.type == 'mps':
-            x = torch.randn_like(init_latent, device='cpu').to(device)
+            x = torch.randn_like(like, device='cpu').to(device)
         else:
-            x = torch.randn_like(init_latent, device=device)
+            x = torch.randn_like(like, device=device)
         if self.perlin > 0.0:
-            shape = init_latent.shape
+            shape = like.shape
             x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
         return x
 
-    def _image_to_tensor(self, image:Image, normalize:bool=True)->Tensor:
-        image = np.array(image).astype(np.float32) / 255.0
-        if len(image.shape) == 2:  # 'L' image, as in a mask
-            image = image[None,None]
-        else:  # 'RGB' image
-            image = image[None].transpose(0, 3, 1, 2)
-        image = torch.from_numpy(image)
-        if normalize:
-            image = 2.0 * image - 1.0
-        return image.to(self.model.device)
+    def get_noise(self,width,height):
+        # copy of the Txt2Img.get_noise
+        device = self.model.device
+        if self.use_mps_noise or device.type == 'mps':
+            x = torch.randn([1,
+                             self.latent_channels,
+                             height // self.downsampling_factor,
+                             width // self.downsampling_factor],
+                            device='cpu').to(device)
+        else:
+            x = torch.randn([1,
+                             self.latent_channels,
+                             height // self.downsampling_factor,
+                             width // self.downsampling_factor],
+                            device=device)
+        if self.perlin > 0.0:
+            x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor)
+        return x
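With both halves in place, Img2Img.get_make_image hands the actual work to the diffusers pipeline. A hypothetical usage sketch of the new entry point (not part of the commit): it assumes pipeline is an already-loaded StableDiffusionGeneratorPipeline and that text_embeddings / unconditioned_embeddings came from the usual prompt-conditioning step; "init.png" is a placeholder path.

import PIL.Image
import torch

init_image = PIL.Image.open("init.png")   # placeholder input image

def noise_func(like: torch.Tensor) -> torch.Tensor:
    # Same contract as Img2Img.get_noise_like: noise shaped like the init latents.
    return torch.randn_like(like)

output = pipeline.img2img_from_embeddings(
    init_image,
    strength=0.75,                 # how much of the schedule to re-noise and re-run
    num_inference_steps=50,
    text_embeddings=text_embeddings,
    unconditioned_embeddings=unconditioned_embeddings,
    guidance_scale=7.5,
    noise_func=noise_func,
)
result = pipeline.numpy_to_pil(output.images)[0]   # a PIL image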