Compare commits

...

2 Commits

2 changed files with 60 additions and 148 deletions

View File

@@ -286,30 +286,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
         unet,
         scheduler,
     ) -> StableDiffusionGeneratorPipeline:
-        # TODO:
-        # configure_model_padding(
-        #    unet,
-        #    self.seamless,
-        #    self.seamless_axes,
-        # )
-
-        class FakeVae:
-            class FakeVaeConfig:
-                def __init__(self):
-                    self.block_out_channels = [0]
-
-            def __init__(self):
-                self.config = FakeVae.FakeVaeConfig()
-
         return StableDiffusionGeneratorPipeline(
-            vae=FakeVae(),  # TODO: oh...
-            text_encoder=None,
-            tokenizer=None,
             unet=unet,
             scheduler=scheduler,
-            safety_checker=None,
-            feature_extractor=None,
-            requires_safety_checker=False,
        )

     def prep_control_data(
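Note: the simplified factory above works because the reworked `StableDiffusionGeneratorPipeline` (second file in this compare) only needs a UNet and a scheduler, so the `FakeVae` placeholder and the `None` safety-checker arguments are no longer required. A minimal usage sketch, assuming the names from the diff:

```python
# Sketch only (not part of the diff): building the denoising pipeline after this change.
# `unet` and `scheduler` are assumed to be the already-loaded UNet2DConditionModel and
# diffusers scheduler that this method receives as arguments.
pipeline = StableDiffusionGeneratorPipeline(
    unet=unet,
    scheduler=scheduler,
)
```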

View File

@@ -12,19 +12,12 @@ import torch
 import torchvision.transforms as T
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.models.controlnet import ControlNetModel
-from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
-from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
-    StableDiffusionPipeline,
-)
-from diffusers.pipelines.stable_diffusion.safety_checker import (
-    StableDiffusionSafetyChecker,
-)
+from diffusers.models.attention_processor import AttnProcessor2_0
 from diffusers.schedulers import KarrasDiffusionSchedulers
 from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.outputs import BaseOutput
 from pydantic import Field
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

 from invokeai.app.services.config import InvokeAIAppConfig
 from .diffusion import (
@@ -34,6 +27,7 @@ from .diffusion import (
     BasicConditioningInfo,
 )
 from ..util import normalize_device, auto_detect_slice_size
+from tqdm.auto import tqdm

 @dataclass
@@ -205,145 +199,80 @@ class ConditioningData:
         return dataclasses.replace(self, scheduler_args=scheduler_args)


-@dataclass
-class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput):
-    r"""
-    Output class for InvokeAI's Stable Diffusion pipeline.
-
-    Args:
-        attention_map_saver (`AttentionMapSaver`): Object containing attention maps that can be displayed to the user
-        after generation completes. Optional.
-    """
-
-    attention_map_saver: Optional[AttentionMapSaver]
-
-
-class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
-    r"""
-    Pipeline for text-to-image generation using Stable Diffusion.
-
-    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
-    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
-
-    Implementation note: This class started as a refactored copy of diffusers.StableDiffusionPipeline.
-    Hopefully future versions of diffusers provide access to more of these functions so that we don't
-    need to duplicate them here: https://github.com/huggingface/diffusers/issues/551#issuecomment-1281508384
-
-    Args:
-        vae ([`AutoencoderKL`]):
-            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
-        text_encoder ([`CLIPTextModel`]):
-            Frozen text-encoder. Stable Diffusion uses the text portion of
-            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
-            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
-        tokenizer (`CLIPTokenizer`):
-            Tokenizer of class
-            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
-        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
-        scheduler ([`SchedulerMixin`]):
-            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
-            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
-        safety_checker ([`StableDiffusionSafetyChecker`]):
-            Classification module that estimates whether generated images could be considered offensive or harmful.
-            Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
-        feature_extractor ([`CLIPFeatureExtractor`]):
-            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
-    """
-
+class StableDiffusionGeneratorPipeline:
     def __init__(
         self,
-        vae: AutoencoderKL,
-        text_encoder: CLIPTextModel,
-        tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
         scheduler: KarrasDiffusionSchedulers,
-        safety_checker: Optional[StableDiffusionSafetyChecker],
-        feature_extractor: Optional[CLIPFeatureExtractor],
-        requires_safety_checker: bool = False,
-        control_model: ControlNetModel = None,
     ):
-        super().__init__(
-            vae,
-            text_encoder,
-            tokenizer,
-            unet,
-            scheduler,
-            safety_checker,
-            feature_extractor,
-            requires_safety_checker,
-        )
-
-        self.register_modules(
-            vae=vae,
-            text_encoder=text_encoder,
-            tokenizer=tokenizer,
-            unet=unet,
-            scheduler=scheduler,
-            safety_checker=safety_checker,
-            feature_extractor=feature_extractor,
-            # FIXME: can't currently register control module
-            # control_model=control_model,
-        )
+        self.unet = unet
+        self.scheduler = scheduler
+
         self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward)
-        self.control_model = control_model

-    def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
+    def set_use_memory_efficient_attention_xformers(
+        self, module: torch.nn.Module, valid: bool, attention_op: Optional[Callable] = None
+    ) -> None:
+        # Recursively walk through all the children.
+        # Any children which exposes the set_use_memory_efficient_attention_xformers method
+        # gets the message
+        def fn_recursive_set_mem_eff(module: torch.nn.Module):
+            if hasattr(module, "set_use_memory_efficient_attention_xformers"):
+                module.set_use_memory_efficient_attention_xformers(valid, attention_op)
+            for child in module.children():
+                fn_recursive_set_mem_eff(child)
+
+        fn_recursive_set_mem_eff(module)
+
+    def set_attention_slice(self, module: torch.nn.Module, slice_size: Optional[int]):
+        if hasattr(module, "set_attention_slice"):
+            module.set_attention_slice(slice_size)
+
+    def _adjust_memory_efficient_attention(self, model, latents: torch.Tensor):
         """
         if xformers is available, use it, otherwise use sliced attention.
         """
         config = InvokeAIAppConfig.get_config()
         if config.attention_type == "xformers":
-            self.enable_xformers_memory_efficient_attention()
-            return
+            self.set_use_memory_efficient_attention_xformers(model, True)
         elif config.attention_type == "sliced":
             slice_size = config.attention_slice_size
             if slice_size == "auto":
                 slice_size = auto_detect_slice_size(latents)
-            elif slice_size == "balanced":
+            if slice_size == "balanced":
                 slice_size = "auto"
-            self.enable_attention_slicing(slice_size=slice_size)
-            return
+            self.set_attention_slice(model, slice_size=slice_size)
         elif config.attention_type == "normal":
-            self.disable_attention_slicing()
-            return
+            self.set_attention_slice(model, slice_size=None)
         elif config.attention_type == "torch-sdp":
-            raise Exception("torch-sdp attention slicing not yet implemented")
-
-        # the remainder if this code is called when attention_type=='auto'
-        if self.unet.device.type == "cuda":
-            if is_xformers_available() and not config.disable_xformers:
-                self.enable_xformers_memory_efficient_attention()
-                return
-            elif hasattr(torch.nn.functional, "scaled_dot_product_attention"):
-                # diffusers enable sdp automatically
-                return
-
-        if self.unet.device.type == "cpu" or self.unet.device.type == "mps":
-            mem_free = psutil.virtual_memory().free
-        elif self.unet.device.type == "cuda":
-            mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.unet.device))
-        else:
-            raise ValueError(f"unrecognized device {self.unet.device}")
-
-        # input tensor of [1, 4, h/8, w/8]
-        # output tensor of [16, (h/8 * w/8), (h/8 * w/8)]
-        bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4
-        max_size_required_for_baddbmm = (
-            16
-            * latents.size(dim=2)
-            * latents.size(dim=3)
-            * latents.size(dim=2)
-            * latents.size(dim=3)
-            * bytes_per_element_needed_for_baddbmm_duplication
-        )
-        if max_size_required_for_baddbmm > (mem_free * 3.0 / 4.0):  # 3.3 / 4.0 is from old Invoke code
-            self.enable_attention_slicing(slice_size="max")
-        elif torch.backends.mps.is_available():
-            # diffusers recommends always enabling for mps
-            self.enable_attention_slicing(slice_size="max")
-        else:
-            self.disable_attention_slicing()
-
-    def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False):
-        raise Exception("Should not be called")
+            if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
+                raise Exception("torch-sdp requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+            model.set_attn_processor(AttnProcessor2_0())
+        else:  # auto
+            if model.device.type == "cuda":
+                if is_xformers_available() and not config.disable_xformers:
+                    self.set_use_memory_efficient_attention_xformers(model, True)
+                elif hasattr(torch.nn.functional, "scaled_dot_product_attention"):
+                    model.set_attn_processor(AttnProcessor2_0())
+            else:
+                if model.device.type == "cpu" or model.device.type == "mps":
+                    mem_free = psutil.virtual_memory().free
+                elif model.device.type == "cuda":
+                    mem_free, _ = torch.cuda.mem_get_info(normalize_device(model.device))
+                else:
+                    raise ValueError(f"unrecognized device {model.device}")
+                slice_size = auto_detect_slice_size(latents)
+                if slice_size == "balanced":
+                    slice_size = "auto"
+                self.set_attention_slice(model, slice_size=slice_size)

     def latents_from_embeddings(
         self,
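The rewritten attention setup above no longer relies on the `DiffusionPipeline` helpers (`enable_xformers_memory_efficient_attention`, `enable_attention_slicing`); it configures any `torch.nn.Module` directly, so the same logic can be applied to ControlNet models as well as the UNet (see the next hunk). A condensed, self-contained sketch of that per-module pattern, assuming only stock PyTorch/diffusers APIs and a loaded `UNet2DConditionModel` or `ControlNetModel`:

```python
# Sketch only (not part of the diff): a simplified variant of the per-model attention
# configuration the new helpers implement.
import torch
from diffusers.models.attention_processor import AttnProcessor2_0


def configure_attention(model: torch.nn.Module) -> None:
    """Prefer torch 2.0 scaled-dot-product attention when available; otherwise
    recursively ask submodules to enable xformers attention."""
    if hasattr(torch.nn.functional, "scaled_dot_product_attention") and hasattr(model, "set_attn_processor"):
        # UNet2DConditionModel and ControlNetModel both expose set_attn_processor()
        model.set_attn_processor(AttnProcessor2_0())
        return

    def recurse(module: torch.nn.Module) -> None:
        # any submodule that understands the xformers toggle gets the message
        if hasattr(module, "set_use_memory_efficient_attention_xformers"):
            module.set_use_memory_efficient_attention_xformers(True, None)
        for child in module.children():
            recurse(child)

    recurse(model)
```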
@@ -429,7 +358,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         control_data: List[ControlNetData] = None,
         callback: Callable[[PipelineIntermediateState], None] = None,
     ):
-        self._adjust_memory_efficient_attention(latents)
+        self._adjust_memory_efficient_attention(self.unet, latents)
+        if control_data is not None:
+            for control in control_data:
+                self._adjust_memory_efficient_attention(control.model, latents)
+
         if additional_guidance is None:
             additional_guidance = []
@@ -457,7 +390,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         )

         # print("timesteps:", timesteps)
-        for i, t in enumerate(self.progress_bar(timesteps)):
+        for i, t in enumerate(tqdm(timesteps)):
             batched_t = t.expand(batch_size)
             step_output = self.step(
                 batched_t,