@@ -5,8 +5,8 @@ from typing import List, Optional, Union, Callable
 
 import PIL.Image
 import torch
+from diffusers import StableDiffusionPipeline
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
-from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import preprocess
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -26,7 +26,7 @@ class PipelineIntermediateState:
     predicted_original: Optional[torch.Tensor] = None
 
 
-class StableDiffusionGeneratorPipeline(DiffusionPipeline):
+class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion.
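
This re-parenting is the change the rest of the diff builds on: the hunks below delete local reimplementations in favor of members inherited from `StableDiffusionPipeline`. A minimal sketch, assuming a diffusers version from the era of this diff, of the inherited members involved:

# Sketch: the base-class members the following hunks start relying on.
from diffusers import StableDiffusionPipeline

for name in ("_encode_prompt", "prepare_latents", "decode_latents", "run_safety_checker"):
    # each one replaces a local copy deleted further down in this diff
    print(name, hasattr(StableDiffusionPipeline, name))
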
@@ -67,10 +67,10 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
         scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
-        safety_checker: StableDiffusionSafetyChecker,
-        feature_extractor: CLIPFeatureExtractor,
+        safety_checker: Optional[StableDiffusionSafetyChecker],
+        feature_extractor: Optional[CLIPFeatureExtractor],
     ):
-        super().__init__()
+        super().__init__(vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor)
 
         self.register_modules(
             vae=vae,
@@ -88,51 +88,6 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
         )
         self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward)
 
-    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
-        r"""
-        Enable sliced attention computation.
-
-        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
-        in several steps. This is useful to save some memory in exchange for a small speed decrease.
-
-        Args:
-            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
-                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
-                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
-                `attention_head_dim` must be a multiple of `slice_size`.
-        """
-        if slice_size == "auto":
-            # half the attention head size is usually a good trade-off between
-            # speed and memory
-            slice_size = self.unet.config.attention_head_dim // 2
-        self.unet.set_attention_slice(slice_size)
-
-    def disable_attention_slicing(self):
-        r"""
-        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
-        back to computing attention in one step.
-        """
-        # set slice_size = `None` to disable `attention slicing`
-        self.enable_attention_slicing(None)
-
-    def enable_xformers_memory_efficient_attention(self):
-        r"""
-        Enable memory efficient attention as implemented in xformers.
-
-        When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
-        time. Speed up at training time is not guaranteed.
-
-        Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
-        is used.
-        """
-        self.unet.set_use_memory_efficient_attention_xformers(True)
-
-    def disable_xformers_memory_efficient_attention(self):
-        r"""
-        Disable memory efficient attention as implemented in xformers.
-        """
-        self.unet.set_use_memory_efficient_attention_xformers(False)
-
     def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
                               text_embeddings: torch.Tensor, unconditioned_embeddings: torch.Tensor,
                               guidance_scale: float,
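
The four deleted helpers are inherited unchanged from `StableDiffusionPipeline`, so callers keep the same API. A hedged usage sketch of the slicing arithmetic the removed docstring describes (`pipe` and `attention_head_dim = 8` are assumptions for illustration):

# With slice_size="auto", the method halves attention_head_dim: 8 // 2 = 4,
# so attention runs in attention_head_dim // slice_size = 2 sequential slices,
# trading a small slowdown for lower peak memory.
pipe.enable_attention_slicing("auto")
# slice_size=None restores one-step attention; disable_attention_slicing() does the same.
pipe.enable_attention_slicing(None)
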
@@ -195,10 +150,17 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         # corresponds to doing no classifier free guidance.
         do_classifier_free_guidance = guidance_scale > 1.0
-        text_embeddings, unconditioned_embeddings = self.get_text_embeddings(prompt, opposing_prompt, do_classifier_free_guidance, batch_size)\
-            .to(self.unet.device)
+        combined_embeddings = self._encode_prompt(prompt, device=self._execution_device, num_images_per_prompt=1,
+                                                  do_classifier_free_guidance=do_classifier_free_guidance,
+                                                  negative_prompt=opposing_prompt)
+        text_embeddings, unconditioned_embeddings = combined_embeddings.chunk(2)
         self.scheduler.set_timesteps(num_inference_steps)
-        latents = self.prepare_latents(latents, batch_size, height, width, generator, self.unet.dtype)
+        latents = self.prepare_latents(batch_size=batch_size, num_channels_latents=self.unet.in_channels,
+                                       height=height, width=width,
+                                       dtype=self.unet.dtype, device=self._execution_device,
+                                       generator=generator,
+                                       latents=latents)
 
         yield from self.generate_from_embeddings(latents, text_embeddings, unconditioned_embeddings,
                                                  guidance_scale, run_id=run_id, **extra_step_kwargs)
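
The comment above refers to the guidance weight `w` in equation (2) of the Imagen paper; the two halves recovered by `combined_embeddings.chunk(2)` feed that formula later in the denoising loop. A hedged sketch of the standard combination step (tensor shapes and names are illustrative, not from this file):

import torch

noise_uncond = torch.zeros(1, 4, 64, 64)  # stand-in for the unconditioned UNet output
noise_text = torch.ones(1, 4, 64, 64)     # stand-in for the text-conditioned UNet output
guidance_scale = 7.5
noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)
# guidance_scale == 1.0 reduces this to noise_text, i.e. no classifier-free guidance,
# which is why the code only enables CFG when guidance_scale > 1.0.
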
@@ -248,9 +210,10 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         torch.cuda.empty_cache()
 
-        image = self.decode_to_image(latents)
-        output = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=[])
-        yield self.check_for_safety(output)
+        with torch.inference_mode():
+            image = self.decode_latents(latents)
+            output = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=[])
+            yield self.check_for_safety(output, dtype=text_embeddings.dtype)
 
     @torch.inference_mode()
     def step(self, t: torch.Tensor, latents: torch.Tensor, guidance_scale: float,
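
A hedged note on the `torch.cuda.empty_cache()` call retained above (context in the linked thread): it releases cached, unused allocator blocks so the VAE decode that follows has headroom; it does not free live tensors.

import torch

if torch.cuda.is_available():            # guard: a no-op on CPU-only machines
    torch.cuda.empty_cache()             # return cached, unused blocks to the driver
    print(torch.cuda.memory_reserved())  # reserved memory typically drops; allocated memory does not
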
@@ -340,46 +303,12 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
 
         return timesteps
 
-    @torch.inference_mode()
-    def check_for_safety(self, output):
-        if not getattr(self, 'feature_extractor') or not getattr(self, 'safety_checker'):
-            return output
-        images = output.images
-        safety_checker_output = self.feature_extractor(self.numpy_to_pil(images),
-                                                       return_tensors="pt").to(self.device)
-        screened_images, has_nsfw_concept = self.safety_checker(
-            images=images, clip_input=safety_checker_output.pixel_values)
+    def check_for_safety(self, output, dtype):
+        with torch.inference_mode():
+            screened_images, has_nsfw_concept = self.run_safety_checker(
+                output.images, device=self._execution_device, dtype=dtype)
         return StableDiffusionPipelineOutput(screened_images, has_nsfw_concept)
 
-    @torch.inference_mode()
-    def decode_to_image(self, latents):
-        # scale and decode the image latents with vae
-        latents = 1 / 0.18215 * latents
-        image = self.vae.decode(latents).sample
-        image = (image / 2 + 0.5).clamp(0, 1)
-        image = image.cpu().permute(0, 2, 3, 1).numpy()
-        return image
-
-    @torch.inference_mode()
-    def get_text_embeddings(self,
-                            prompt: Union[str, List[str]],
-                            opposing_prompt: Union[str, List[str]],
-                            do_classifier_free_guidance: bool,
-                            batch_size: int):
-        # get prompt text embeddings
-        text_input = self._tokenize(prompt)
-        text_embeddings = self.text_encoder(text_input.input_ids)[0]
-
-        # get unconditional embeddings for classifier free guidance
-        if do_classifier_free_guidance:
-            # opposing prompt defaults to blank caption for everything in the batch
-            text_anti_input = self._tokenize(opposing_prompt or [""] * batch_size)
-            uncond_embeddings = self.text_encoder(text_anti_input.input_ids)[0]
-        else:
-            uncond_embeddings = None
-
-        return text_embeddings, uncond_embeddings
-
     @torch.inference_mode()
     def get_learned_conditioning(self, c: List[List[str]], *, return_tokens=True, fragment_weights=None):
         """
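
The removed `decode_to_image` mirrors the inherited `decode_latents` step for step, and the `1 / 0.18215` factor undoes the scaling applied when images enter latent space. A hedged round-trip sketch (the `vae` and `image` objects are assumed; attribute names follow diffusers' `AutoencoderKL`):

# encode: latents are scaled by 0.18215 so they are roughly unit variance
latents = vae.encode(image).latent_dist.sample() * 0.18215
# decode: divide the factor back out before running the VAE decoder
decoded = vae.decode(latents / 0.18215).sample
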
@@ -406,28 +335,3 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
     def channels(self) -> int:
         """Compatible with DiffusionWrapper"""
         return self.unet.in_channels
-
-    def prepare_latents(self, latents, batch_size, height, width, generator, dtype):
-        # get the initial random noise unless the user supplied it
-        # Unlike in other pipelines, latents need to be generated in the target device
-        # for 1-to-1 results reproducibility with the CompVis implementation.
-        # However this currently doesn't work in `mps`.
-        latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
-        if latents is None:
-            latents = torch.randn(
-                latents_shape,
-                generator=generator,
-                device=self.unet.device,
-                dtype=dtype
-            )
-        else:
-            if latents.shape != latents_shape:
-                raise ValueError(
-                    f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
-            if latents.device != self.unet.device:
-                raise ValueError(f"Unexpected latents device, got {latents.device}, "
-                                 f"expected {self.unet.device}")
-
-        # scale the initial noise by the standard deviation required by the scheduler
-        latents *= self.scheduler.init_noise_sigma
-        return latents
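
The deleted method's final step survives in the inherited `prepare_latents`: fresh unit-variance noise is multiplied by `scheduler.init_noise_sigma`. A hedged sketch of why that scaling matters for the schedulers this pipeline accepts (exact values depend on the configured beta schedule and diffusers version):

import torch
from diffusers import DDIMScheduler, LMSDiscreteScheduler

noise = torch.randn(1, 4, 64, 64)               # unit-variance latent noise
print(DDIMScheduler().init_noise_sigma)         # 1.0: DDIM expects unit variance, so the scaling is a no-op
print(LMSDiscreteScheduler().init_noise_sigma)  # > 1: LMS denoises in sigma space, from the largest sigma
latents = noise * LMSDiscreteScheduler().init_noise_sigma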