Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Remove requirements on the diffusers pipeline, add support for torch-sdp, and apply attention settings to ControlNet models too
This commit is contained in:
parent a74e2108bb
commit 6bb657b3f3
@@ -286,30 +286,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
         unet,
         scheduler,
     ) -> StableDiffusionGeneratorPipeline:
         # TODO:
         # configure_model_padding(
         #    unet,
         #    self.seamless,
         #    self.seamless_axes,
         # )

         class FakeVae:
             class FakeVaeConfig:
                 def __init__(self):
                     self.block_out_channels = [0]

             def __init__(self):
                 self.config = FakeVae.FakeVaeConfig()

         return StableDiffusionGeneratorPipeline(
             vae=FakeVae(),  # TODO: oh...
             text_encoder=None,
             tokenizer=None,
             unet=unet,
             scheduler=scheduler,
             safety_checker=None,
             feature_extractor=None,
             requires_safety_checker=False,
         )

     def prep_control_data(
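Side note (my reading, not part of the diff): the FakeVae stub likely exists because diffusers' StableDiffusionPipeline constructor reads vae.config.block_out_channels to derive its VAE scale factor, so a config-bearing dummy keeps the constructor satisfied without a real autoencoder. A minimal sketch of that assumption:

    # Assumption: StableDiffusionPipeline.__init__ computes, roughly,
    #   vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
    # With block_out_channels = [0], that yields a scale factor of 1.
    block_out_channels = [0]
    vae_scale_factor = 2 ** (len(block_out_channels) - 1)
    assert vae_scale_factor == 1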
@@ -12,19 +12,12 @@ import torch
 import torchvision.transforms as T
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.models.controlnet import ControlNetModel
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
     StableDiffusionPipeline,
 )
 from diffusers.pipelines.stable_diffusion.safety_checker import (
     StableDiffusionSafetyChecker,
 )
+from diffusers.models.attention_processor import AttnProcessor2_0
 from diffusers.schedulers import KarrasDiffusionSchedulers
 from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.outputs import BaseOutput
 from pydantic import Field
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

 from invokeai.app.services.config import InvokeAIAppConfig
 from .diffusion import (
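Aside on the new import (my illustration, not from the commit): AttnProcessor2_0 is diffusers' attention processor that dispatches to PyTorch 2.0's fused scaled-dot-product attention kernel. A runnable shape-level sketch of that kernel:

    import torch
    import torch.nn.functional as F

    # (batch, heads, tokens, head_dim) -- illustrative sizes only
    q = torch.randn(1, 8, 4096, 40)
    k = torch.randn(1, 8, 4096, 40)
    v = torch.randn(1, 8, 4096, 40)
    out = F.scaled_dot_product_attention(q, k, v)  # fused kernel, PyTorch >= 2.0
    print(out.shape)  # torch.Size([1, 8, 4096, 40])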
@@ -34,6 +27,7 @@ from .diffusion import (
     BasicConditioningInfo,
 )
 from ..util import normalize_device, auto_detect_slice_size
+from tqdm.auto import tqdm


 @dataclass
@@ -205,94 +199,43 @@
         return dataclasses.replace(self, scheduler_args=scheduler_args)


 @dataclass
 class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput):
     r"""
     Output class for InvokeAI's Stable Diffusion pipeline.

     Args:
         attention_map_saver (`AttentionMapSaver`): Object containing attention maps that can be displayed to the user
             after generation completes. Optional.
     """

     attention_map_saver: Optional[AttentionMapSaver]
-class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
-    r"""
-    Pipeline for text-to-image generation using Stable Diffusion.
-
-    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
-    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
-
-    Implementation note: This class started as a refactored copy of diffusers.StableDiffusionPipeline.
-    Hopefully future versions of diffusers provide access to more of these functions so that we don't
-    need to duplicate them here: https://github.com/huggingface/diffusers/issues/551#issuecomment-1281508384
-
-    Args:
-        vae ([`AutoencoderKL`]):
-            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
-        text_encoder ([`CLIPTextModel`]):
-            Frozen text-encoder. Stable Diffusion uses the text portion of
-            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
-            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
-        tokenizer (`CLIPTokenizer`):
-            Tokenizer of class
-            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
-        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
-        scheduler ([`SchedulerMixin`]):
-            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
-            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
-        safety_checker ([`StableDiffusionSafetyChecker`]):
-            Classification module that estimates whether generated images could be considered offensive or harmful.
-            Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
-        feature_extractor ([`CLIPFeatureExtractor`]):
-            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
-    """
+class StableDiffusionGeneratorPipeline:
     def __init__(
         self,
         vae: AutoencoderKL,
         text_encoder: CLIPTextModel,
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
         scheduler: KarrasDiffusionSchedulers,
         safety_checker: Optional[StableDiffusionSafetyChecker],
         feature_extractor: Optional[CLIPFeatureExtractor],
         requires_safety_checker: bool = False,
         control_model: ControlNetModel = None,
     ):
-        super().__init__(
-            vae,
-            text_encoder,
-            tokenizer,
-            unet,
-            scheduler,
-            safety_checker,
-            feature_extractor,
-            requires_safety_checker,
-        )
-
-        self.register_modules(
-            vae=vae,
-            text_encoder=text_encoder,
-            tokenizer=tokenizer,
-            unet=unet,
-            scheduler=scheduler,
-            safety_checker=safety_checker,
-            feature_extractor=feature_extractor,
-            # FIXME: can't currently register control module
-            # control_model=control_model,
-        )
+        self.unet = unet
+        self.scheduler = scheduler
         self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward)
         self.control_model = control_model
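A side sketch (my illustration, with hypothetical names): register_modules comes from diffusers' DiffusionPipeline base, where it powers saving, loading, and device moves; once that base class is dropped, plain attribute assignment is what remains.

    # Hypothetical minimal stand-in for the slimmed-down __init__:
    class MinimalGeneratorPipeline:
        def __init__(self, unet, scheduler, control_model=None):
            # no DiffusionPipeline base -> no register_modules(); assign directly
            self.unet = unet
            self.scheduler = scheduler
            self.control_model = control_model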
-    def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
+    def set_use_memory_efficient_attention_xformers(
+        self, module: torch.nn.Module, valid: bool, attention_op: Optional[Callable] = None
+    ) -> None:
+        # Recursively walk through all the children.
+        # Any child which exposes the set_use_memory_efficient_attention_xformers method
+        # gets the message
+        def fn_recursive_set_mem_eff(module: torch.nn.Module):
+            if hasattr(module, "set_use_memory_efficient_attention_xformers"):
+                module.set_use_memory_efficient_attention_xformers(valid, attention_op)
+
+            for child in module.children():
+                fn_recursive_set_mem_eff(child)
+
+        fn_recursive_set_mem_eff(module)
+
+    def set_attention_slice(self, module: torch.nn.Module, slice_size: Optional[int]):
+        if hasattr(module, "set_attention_slice"):
+            module.set_attention_slice(slice_size)
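A self-contained sketch of the recursive walk these helpers rely on (toy module names of my own, not InvokeAI code): only submodules exposing the hook react, while the traversal still visits every child.

    import torch

    class AttnLeaf(torch.nn.Module):
        def set_use_memory_efficient_attention_xformers(self, valid, attention_op=None):
            print(f"xformers enabled: {valid}")

    root = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Sequential(AttnLeaf()))

    def fn_recursive_set_mem_eff(module):
        if hasattr(module, "set_use_memory_efficient_attention_xformers"):
            module.set_use_memory_efficient_attention_xformers(True)
        for child in module.children():
            fn_recursive_set_mem_eff(child)

    fn_recursive_set_mem_eff(root)  # prints once, from the nested AttnLeaf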
+    def _adjust_memory_efficient_attention(self, model, latents: torch.Tensor):
         """
         if xformers is available, use it, otherwise use sliced attention.
         """
         config = InvokeAIAppConfig.get_config()
         if config.attention_type == "xformers":
-            self.enable_xformers_memory_efficient_attention()
+            self.set_use_memory_efficient_attention_xformers(model, True)
             return
         elif config.attention_type == "sliced":
             slice_size = config.attention_slice_size
@@ -300,29 +243,32 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 slice_size = auto_detect_slice_size(latents)
             elif slice_size == "balanced":
                 slice_size = "auto"
-            self.enable_attention_slicing(slice_size=slice_size)
+            self.set_attention_slice(model, slice_size=slice_size)
             return
         elif config.attention_type == "normal":
-            self.disable_attention_slicing()
+            self.set_attention_slice(model, slice_size=None)
             return
         elif config.attention_type == "torch-sdp":
-            raise Exception("torch-sdp attention slicing not yet implemented")
+            if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
+                raise Exception("torch-sdp requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+            model.set_attn_processor(AttnProcessor2_0())
+            return

         # the remainder of this code is called when attention_type=='auto'
-        if self.unet.device.type == "cuda":
+        if model.device.type == "cuda":
             if is_xformers_available() and not config.disable_xformers:
-                self.enable_xformers_memory_efficient_attention()
+                self.set_use_memory_efficient_attention_xformers(model, True)
                 return
             elif hasattr(torch.nn.functional, "scaled_dot_product_attention"):
-                # diffusers enable sdp automatically
+                model.set_attn_processor(AttnProcessor2_0())
                 return
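A hedged sketch of the wiring the torch-sdp branch performs (assuming diffusers' set_attn_processor API; loading a real UNet or ControlNet is elided):

    import torch
    from diffusers.models.attention_processor import AttnProcessor2_0

    def enable_torch_sdp(model) -> None:
        # Guard mirrors the diff: SDPA only exists from PyTorch 2.0 onward.
        if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
            raise Exception("torch-sdp requires PyTorch 2.0")
        model.set_attn_processor(AttnProcessor2_0())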
-        if self.unet.device.type == "cpu" or self.unet.device.type == "mps":
+        if model.device.type == "cpu" or model.device.type == "mps":
             mem_free = psutil.virtual_memory().free
-        elif self.unet.device.type == "cuda":
-            mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.unet.device))
+        elif model.device.type == "cuda":
+            mem_free, _ = torch.cuda.mem_get_info(normalize_device(model.device))
         else:
-            raise ValueError(f"unrecognized device {self.unet.device}")
+            raise ValueError(f"unrecognized device {model.device}")
         # input tensor of [1, 4, h/8, w/8]
         # output tensor of [16, (h/8 * w/8), (h/8 * w/8)]
         bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4
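A worked instance of this estimate (my numbers, assuming a 512x512 fp16 generation; the multiplication itself continues in the next hunk):

    import torch

    # 512x512 image -> latents of shape [1, 4, 64, 64] (h/8 = w/8 = 64)
    latents = torch.zeros(1, 4, 64, 64, dtype=torch.float16)
    bytes_per_element = latents.element_size() + 4      # 2 (fp16) + 4 = 6
    # baddbmm output is [16, 4096, 4096] per the comments above
    max_size = 16 * (64 * 64) ** 2 * bytes_per_element  # = 1_610_612_736 bytes, ~1.5 GiB
    print(max_size)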
@@ -335,15 +281,12 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 * bytes_per_element_needed_for_baddbmm_duplication
             )
             if max_size_required_for_baddbmm > (mem_free * 3.0 / 4.0):  # 3.3 / 4.0 is from old Invoke code
-                self.enable_attention_slicing(slice_size="max")
+                self.set_attention_slice(model, slice_size="max")
             elif torch.backends.mps.is_available():
                 # diffusers recommends always enabling for mps
-                self.enable_attention_slicing(slice_size="max")
+                self.set_attention_slice(model, slice_size="max")
             else:
-                self.disable_attention_slicing()
+                self.set_attention_slice(model, slice_size=None)

-    def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False):
-        raise Exception("Should not be called")
-
     def latents_from_embeddings(
         self,
@@ -429,7 +372,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         control_data: List[ControlNetData] = None,
         callback: Callable[[PipelineIntermediateState], None] = None,
     ):
-        self._adjust_memory_efficient_attention(latents)
+        self._adjust_memory_efficient_attention(self.unet, latents)
+        if control_data is not None:
+            for control in control_data:
+                self._adjust_memory_efficient_attention(control.model, latents)

         if additional_guidance is None:
             additional_guidance = []
@@ -457,7 +404,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         )

         # print("timesteps:", timesteps)
-        for i, t in enumerate(self.progress_bar(timesteps)):
+        for i, t in enumerate(tqdm(timesteps)):
             batched_t = t.expand(batch_size)
             step_output = self.step(
                 batched_t,
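For illustration (a toy schedule of my own, not InvokeAI code): progress_bar is a DiffusionPipeline helper, so with the base class gone the denoising loop wraps the timestep iterator with tqdm directly.

    import torch
    from tqdm.auto import tqdm

    timesteps = torch.linspace(999, 0, 30).long()  # toy schedule for illustration
    for i, t in enumerate(tqdm(timesteps)):
        pass  # one denoising step would run here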