dev: upgrade to diffusers 0.8 (from 0.7.1)

We get to remove some code by using methods that were factored out in the base class.
Kevin Turner 2022-11-23 14:46:41 -08:00
parent efbb807905
commit ceb53ccdfb
6 changed files with 28 additions and 124 deletions
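
For orientation (a note added here, not part of the commit itself): in diffusers ~=0.8, StableDiffusionPipeline already ships the helpers this subclass used to carry, so the changes below mostly delete local copies and call the inherited versions instead. A minimal sketch of what that inheritance provides, assuming the diffusers 0.8 API (the checkpoint id and call sites are illustrative only):

# Sketch only; assumes diffusers ~=0.8 and an illustrative checkpoint id.
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Memory helpers now inherited instead of re-implemented in the subclass:
pipe.enable_attention_slicing()
pipe.disable_attention_slicing()
pipe.enable_xformers_memory_efficient_attention()

# Protected building blocks the subclass now calls in place of its own copies:
#   self._encode_prompt(...)      replaces get_text_embeddings()
#   self.prepare_latents(...)     replaces the local prepare_latents()
#   self.decode_latents(...)      replaces decode_to_image()
#   self.run_safety_checker(...)  replaces the hand-rolled safety check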

@@ -11,7 +11,7 @@ dependencies:
 - --extra-index-url https://download.pytorch.org/whl/rocm5.2/
 - albumentations==0.4.3
 - dependency_injector==4.40.0
-- diffusers==0.6.0
+- diffusers~=0.8
 - einops==0.3.0
 - eventlet
 - flask==2.1.3

@@ -15,7 +15,7 @@ dependencies:
 - accelerate~=0.13
 - albumentations==0.4.3
 - dependency_injector==4.40.0
-- diffusers==0.6.0
+- diffusers~=0.8
 - einops==0.3.0
 - eventlet
 - flask==2.1.3

@@ -22,7 +22,7 @@ dependencies:
 - albumentations=1.2
 - coloredlogs=15.0
-- diffusers~=0.7
+- diffusers~=0.8
 - einops=0.3
 - eventlet
 - grpcio=1.46

@@ -15,7 +15,7 @@ dependencies:
 - albumentations==0.4.3
 - basicsr==1.4.1
 - dependency_injector==4.40.0
-- diffusers==0.6.0
+- diffusers~=0.8
 - einops==0.3.0
 - eventlet
 - flask==2.1.3

@@ -1,7 +1,7 @@
 # pip will resolve the version which matches torch
 albumentations
 dependency_injector==4.40.0
-diffusers[torch]~=0.7
+diffusers[torch]~=0.8
 einops
 eventlet
 facexlib
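
Aside (added for clarity, not part of the commit): ~= is PEP 440's compatible-release operator, so the pin above means "0.8 or newer, but stay on the 0.x series". In requirements terms:

# diffusers[torch]~=0.8  is equivalent to  diffusers[torch]>=0.8,==0.*  (i.e. >=0.8 but <1.0)
diffusers[torch]~=0.8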

@@ -5,8 +5,8 @@ from typing import List, Optional, Union, Callable
 import PIL.Image
 import torch
+from diffusers import StableDiffusionPipeline
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
-from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import preprocess
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -26,7 +26,7 @@ class PipelineIntermediateState:
     predicted_original: Optional[torch.Tensor] = None
 
-class StableDiffusionGeneratorPipeline(DiffusionPipeline):
+class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion.
@@ -67,10 +67,10 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
         scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
-        safety_checker: StableDiffusionSafetyChecker,
-        feature_extractor: CLIPFeatureExtractor,
+        safety_checker: Optional[StableDiffusionSafetyChecker],
+        feature_extractor: Optional[CLIPFeatureExtractor],
     ):
-        super().__init__()
+        super().__init__(vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor)
         self.register_modules(
             vae=vae,
@@ -88,51 +88,6 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
         )
         self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward)
 
-    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
-        r"""
-        Enable sliced attention computation.
-
-        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
-        in several steps. This is useful to save some memory in exchange for a small speed decrease.
-
-        Args:
-            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
-                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
-                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
-                `attention_head_dim` must be a multiple of `slice_size`.
-        """
-        if slice_size == "auto":
-            # half the attention head size is usually a good trade-off between
-            # speed and memory
-            slice_size = self.unet.config.attention_head_dim // 2
-        self.unet.set_attention_slice(slice_size)
-
-    def disable_attention_slicing(self):
-        r"""
-        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
-        back to computing attention in one step.
-        """
-        # set slice_size = `None` to disable `attention slicing`
-        self.enable_attention_slicing(None)
-
-    def enable_xformers_memory_efficient_attention(self):
-        r"""
-        Enable memory efficient attention as implemented in xformers.
-
-        When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
-        time. Speed up at training time is not guaranteed.
-
-        Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
-        is used.
-        """
-        self.unet.set_use_memory_efficient_attention_xformers(True)
-
-    def disable_xformers_memory_efficient_attention(self):
-        r"""
-        Disable memory efficient attention as implemented in xformers.
-        """
-        self.unet.set_use_memory_efficient_attention_xformers(False)
-
     def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
                               text_embeddings: torch.Tensor, unconditioned_embeddings: torch.Tensor,
                               guidance_scale: float,
@@ -195,10 +150,17 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         # corresponds to doing no classifier free guidance.
         do_classifier_free_guidance = guidance_scale > 1.0
-        text_embeddings, unconditioned_embeddings = self.get_text_embeddings(prompt, opposing_prompt, do_classifier_free_guidance, batch_size)\
-            .to(self.unet.device)
+        combined_embeddings = self._encode_prompt(prompt, device=self._execution_device, num_images_per_prompt=1,
+                                                  do_classifier_free_guidance=do_classifier_free_guidance,
+                                                  negative_prompt=opposing_prompt)
+        text_embeddings, unconditioned_embeddings = combined_embeddings.chunk(2)
         self.scheduler.set_timesteps(num_inference_steps)
-        latents = self.prepare_latents(latents, batch_size, height, width, generator, self.unet.dtype)
+        latents = self.prepare_latents(batch_size=batch_size, num_channels_latents=self.unet.in_channels,
+                                       height=height, width=width,
+                                       dtype=self.unet.dtype, device=self._execution_device,
+                                       generator=generator,
+                                       latents=latents)
 
         yield from self.generate_from_embeddings(latents, text_embeddings, unconditioned_embeddings,
                                                  guidance_scale, run_id=run_id, **extra_step_kwargs)
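
Note (added for clarity, not in the diff): when classifier-free guidance is enabled, _encode_prompt returns both the conditioned and unconditioned embeddings packed into a single batched tensor, which is why the call above is followed by chunk(2) to split them apart again. A rough illustration, where pipe, prompt and opposing_prompt stand in for the real objects:

# Illustration only; assumes the diffusers 0.8 _encode_prompt signature used above.
combined = pipe._encode_prompt(prompt, device=pipe._execution_device,
                               num_images_per_prompt=1,
                               do_classifier_free_guidance=True,
                               negative_prompt=opposing_prompt)
text_embeddings, unconditioned_embeddings = combined.chunk(2)  # same split as in the hunk above
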
@@ -248,9 +210,10 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
             # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
             torch.cuda.empty_cache()
 
-        image = self.decode_to_image(latents)
-        output = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=[])
-        yield self.check_for_safety(output)
+        with torch.inference_mode():
+            image = self.decode_latents(latents)
+            output = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=[])
+            yield self.check_for_safety(output, dtype=text_embeddings.dtype)
 
     @torch.inference_mode()
     def step(self, t: torch.Tensor, latents: torch.Tensor, guidance_scale: float,
@@ -340,46 +303,12 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
         return timesteps
 
-    @torch.inference_mode()
-    def check_for_safety(self, output):
-        if not getattr(self, 'feature_extractor') or not getattr(self, 'safety_checker'):
-            return output
-        images = output.images
-        safety_checker_output = self.feature_extractor(self.numpy_to_pil(images),
-                                                       return_tensors="pt").to(self.device)
-        screened_images, has_nsfw_concept = self.safety_checker(
-            images=images, clip_input=safety_checker_output.pixel_values)
+    def check_for_safety(self, output, dtype):
+        with torch.inference_mode():
+            screened_images, has_nsfw_concept = self.run_safety_checker(
+                output.images, device=self._execution_device, dtype=dtype)
         return StableDiffusionPipelineOutput(screened_images, has_nsfw_concept)
 
-    @torch.inference_mode()
-    def decode_to_image(self, latents):
-        # scale and decode the image latents with vae
-        latents = 1 / 0.18215 * latents
-        image = self.vae.decode(latents).sample
-        image = (image / 2 + 0.5).clamp(0, 1)
-        image = image.cpu().permute(0, 2, 3, 1).numpy()
-        return image
-
-    @torch.inference_mode()
-    def get_text_embeddings(self,
-                            prompt: Union[str, List[str]],
-                            opposing_prompt: Union[str, List[str]],
-                            do_classifier_free_guidance: bool,
-                            batch_size: int):
-        # get prompt text embeddings
-        text_input = self._tokenize(prompt)
-        text_embeddings = self.text_encoder(text_input.input_ids)[0]
-
-        # get unconditional embeddings for classifier free guidance
-        if do_classifier_free_guidance:
-            # opposing prompt defaults to blank caption for everything in the batch
-            text_anti_input = self._tokenize(opposing_prompt or [""] * batch_size)
-            uncond_embeddings = self.text_encoder(text_anti_input.input_ids)[0]
-        else:
-            uncond_embeddings = None
-
-        return text_embeddings, uncond_embeddings
-
     @torch.inference_mode()
     def get_learned_conditioning(self, c: List[List[str]], *, return_tokens=True, fragment_weights=None):
         """
@@ -406,28 +335,3 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
     def channels(self) -> int:
         """Compatible with DiffusionWrapper"""
         return self.unet.in_channels
-
-    def prepare_latents(self, latents, batch_size, height, width, generator, dtype):
-        # get the initial random noise unless the user supplied it
-        # Unlike in other pipelines, latents need to be generated in the target device
-        # for 1-to-1 results reproducibility with the CompVis implementation.
-        # However this currently doesn't work in `mps`.
-        latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
-        if latents is None:
-            latents = torch.randn(
-                latents_shape,
-                generator=generator,
-                device=self.unet.device,
-                dtype=dtype
-            )
-        else:
-            if latents.shape != latents_shape:
-                raise ValueError(
-                    f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
-            if latents.device != self.unet.device:
-                raise ValueError(f"Unexpected latents device, got {latents.device}, "
-                                 f"expected {self.unet.device}")
-
-        # scale the initial noise by the standard deviation required by the scheduler
-        latents *= self.scheduler.init_noise_sigma
-        return latents