Remove requirements to diffusers pipeline, add support for torch-sdp, apply attention to controlnet models too

This commit is contained in:
Sergey Borisov 2023-09-01 00:18:31 +03:00
parent a74e2108bb
commit 6bb657b3f3
2 changed files with 49 additions and 123 deletions

View File

@ -286,30 +286,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
unet,
scheduler,
) -> StableDiffusionGeneratorPipeline:
# TODO:
# configure_model_padding(
# unet,
# self.seamless,
# self.seamless_axes,
# )
class FakeVae:
class FakeVaeConfig:
def __init__(self):
self.block_out_channels = [0]
def __init__(self):
self.config = FakeVae.FakeVaeConfig()
return StableDiffusionGeneratorPipeline(
vae=FakeVae(), # TODO: oh...
text_encoder=None,
tokenizer=None,
unet=unet,
scheduler=scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
def prep_control_data(

View File

@ -12,19 +12,12 @@ import torch
import torchvision.transforms as T
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.controlnet import ControlNetModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
StableDiffusionPipeline,
)
from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
)
from diffusers.models.attention_processor import AttnProcessor2_0
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.outputs import BaseOutput
from pydantic import Field
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from invokeai.app.services.config import InvokeAIAppConfig
from .diffusion import (
@ -34,6 +27,7 @@ from .diffusion import (
BasicConditioningInfo,
)
from ..util import normalize_device, auto_detect_slice_size
from tqdm.auto import tqdm
@dataclass
@ -205,94 +199,43 @@ class ConditioningData:
return dataclasses.replace(self, scheduler_args=scheduler_args)
@dataclass
class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput):
r"""
Output class for InvokeAI's Stable Diffusion pipeline.
Args:
attention_map_saver (`AttentionMapSaver`): Object containing attention maps that can be displayed to the user
after generation completes. Optional.
"""
attention_map_saver: Optional[AttentionMapSaver]
class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
r"""
Pipeline for text-to-image generation using Stable Diffusion.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Implementation note: This class started as a refactored copy of diffusers.StableDiffusionPipeline.
Hopefully future versions of diffusers provide access to more of these functions so that we don't
need to duplicate them here: https://github.com/huggingface/diffusers/issues/551#issuecomment-1281508384
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]):
Frozen text-encoder. Stable Diffusion uses the text portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
class StableDiffusionGeneratorPipeline:
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers,
safety_checker: Optional[StableDiffusionSafetyChecker],
feature_extractor: Optional[CLIPFeatureExtractor],
requires_safety_checker: bool = False,
control_model: ControlNetModel = None,
):
super().__init__(
vae,
text_encoder,
tokenizer,
unet,
scheduler,
safety_checker,
feature_extractor,
requires_safety_checker,
)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
# FIXME: can't currently register control module
# control_model=control_model,
)
self.unet = unet
self.scheduler = scheduler
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward)
self.control_model = control_model
def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
def set_use_memory_efficient_attention_xformers(
self, module: torch.nn.Module, valid: bool, attention_op: Optional[Callable] = None
) -> None:
# Recursively walk through all the children.
# Any children which exposes the set_use_memory_efficient_attention_xformers method
# gets the message
def fn_recursive_set_mem_eff(module: torch.nn.Module):
if hasattr(module, "set_use_memory_efficient_attention_xformers"):
module.set_use_memory_efficient_attention_xformers(valid, attention_op)
for child in module.children():
fn_recursive_set_mem_eff(child)
fn_recursive_set_mem_eff(module)
def set_attention_slice(self, module: torch.nn.Module, slice_size: Optional[int]):
if hasattr(module, "set_attention_slice"):
module.set_attention_slice(slice_size)
def _adjust_memory_efficient_attention(self, model, latents: torch.Tensor):
"""
if xformers is available, use it, otherwise use sliced attention.
"""
config = InvokeAIAppConfig.get_config()
if config.attention_type == "xformers":
self.enable_xformers_memory_efficient_attention()
self.set_use_memory_efficient_attention_xformers(model, True)
return
elif config.attention_type == "sliced":
slice_size = config.attention_slice_size
@ -300,29 +243,32 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
slice_size = auto_detect_slice_size(latents)
elif slice_size == "balanced":
slice_size = "auto"
self.enable_attention_slicing(slice_size=slice_size)
self.set_attention_slice(model, slice_size=slice_size)
return
elif config.attention_type == "normal":
self.disable_attention_slicing()
self.set_attention_slice(model, slice_size=None)
return
elif config.attention_type == "torch-sdp":
raise Exception("torch-sdp attention slicing not yet implemented")
if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
raise Exception("torch-sdp requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
model.set_attn_processor(AttnProcessor2_0())
return
# the remainder if this code is called when attention_type=='auto'
if self.unet.device.type == "cuda":
if model.device.type == "cuda":
if is_xformers_available() and not config.disable_xformers:
self.enable_xformers_memory_efficient_attention()
self.set_use_memory_efficient_attention_xformers(model, True)
return
elif hasattr(torch.nn.functional, "scaled_dot_product_attention"):
# diffusers enable sdp automatically
model.set_attn_processor(AttnProcessor2_0())
return
if self.unet.device.type == "cpu" or self.unet.device.type == "mps":
if model.device.type == "cpu" or model.device.type == "mps":
mem_free = psutil.virtual_memory().free
elif self.unet.device.type == "cuda":
mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.unet.device))
elif model.device.type == "cuda":
mem_free, _ = torch.cuda.mem_get_info(normalize_device(model.device))
else:
raise ValueError(f"unrecognized device {self.unet.device}")
raise ValueError(f"unrecognized device {model.device}")
# input tensor of [1, 4, h/8, w/8]
# output tensor of [16, (h/8 * w/8), (h/8 * w/8)]
bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4
@ -335,15 +281,12 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
* bytes_per_element_needed_for_baddbmm_duplication
)
if max_size_required_for_baddbmm > (mem_free * 3.0 / 4.0): # 3.3 / 4.0 is from old Invoke code
self.enable_attention_slicing(slice_size="max")
self.set_attention_slice(model, slice_size="max")
elif torch.backends.mps.is_available():
# diffusers recommends always enabling for mps
self.enable_attention_slicing(slice_size="max")
self.set_attention_slice(model, slice_size="max")
else:
self.disable_attention_slicing()
def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False):
raise Exception("Should not be called")
self.set_attention_slice(model, slice_size=None)
def latents_from_embeddings(
self,
@ -429,7 +372,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
control_data: List[ControlNetData] = None,
callback: Callable[[PipelineIntermediateState], None] = None,
):
self._adjust_memory_efficient_attention(latents)
self._adjust_memory_efficient_attention(self.unet, latents)
if control_data is not None:
for control in control_data:
self._adjust_memory_efficient_attention(control.model, latents)
if additional_guidance is None:
additional_guidance = []
@ -457,7 +404,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
)
# print("timesteps:", timesteps)
for i, t in enumerate(self.progress_bar(timesteps)):
for i, t in enumerate(tqdm(timesteps)):
batched_t = t.expand(batch_size)
step_output = self.step(
batched_t,