diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index 6357c1ac7b..659a349607 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -286,30 +286,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
         unet,
         scheduler,
     ) -> StableDiffusionGeneratorPipeline:
-        # TODO:
-        # configure_model_padding(
-        #    unet,
-        #    self.seamless,
-        #    self.seamless_axes,
-        # )
-
-        class FakeVae:
-            class FakeVaeConfig:
-                def __init__(self):
-                    self.block_out_channels = [0]
-
-            def __init__(self):
-                self.config = FakeVae.FakeVaeConfig()
-
         return StableDiffusionGeneratorPipeline(
-            vae=FakeVae(),  # TODO: oh...
-            text_encoder=None,
-            tokenizer=None,
             unet=unet,
             scheduler=scheduler,
-            safety_checker=None,
-            feature_extractor=None,
-            requires_safety_checker=False,
         )
 
     def prep_control_data(
diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
index d88313f455..89a22bb416 100644
--- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py
+++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -12,19 +12,12 @@ import torch
 import torchvision.transforms as T
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.models.controlnet import ControlNetModel
-from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
-from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
-    StableDiffusionPipeline,
-)
-from diffusers.pipelines.stable_diffusion.safety_checker import (
-    StableDiffusionSafetyChecker,
-)
+from diffusers.models.attention_processor import AttnProcessor2_0
 from diffusers.schedulers import KarrasDiffusionSchedulers
 from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.outputs import BaseOutput
 from pydantic import Field
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
 from invokeai.app.services.config import InvokeAIAppConfig
 from .diffusion import (
@@ -34,6 +27,7 @@ from .diffusion import (
     BasicConditioningInfo,
 )
 from ..util import normalize_device, auto_detect_slice_size
+from tqdm.auto import tqdm
 
 
 @dataclass
@@ -205,94 +199,43 @@ class ConditioningData:
         return dataclasses.replace(self, scheduler_args=scheduler_args)
 
 
-@dataclass
-class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput):
-    r"""
-    Output class for InvokeAI's Stable Diffusion pipeline.
-
-    Args:
-        attention_map_saver (`AttentionMapSaver`): Object containing attention maps that can be displayed to the user
-            after generation completes. Optional.
-    """
-
-    attention_map_saver: Optional[AttentionMapSaver]
-
-
-class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
-    r"""
-    Pipeline for text-to-image generation using Stable Diffusion.
-
-    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
-    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
-
-    Implementation note: This class started as a refactored copy of diffusers.StableDiffusionPipeline.
-    Hopefully future versions of diffusers provide access to more of these functions so that we don't
-    need to duplicate them here: https://github.com/huggingface/diffusers/issues/551#issuecomment-1281508384
-
-    Args:
-        vae ([`AutoencoderKL`]):
-            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
-        text_encoder ([`CLIPTextModel`]):
-            Frozen text-encoder. Stable Diffusion uses the text portion of
-            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
-            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
-        tokenizer (`CLIPTokenizer`):
-            Tokenizer of class
-            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
-        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
-        scheduler ([`SchedulerMixin`]):
-            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
-            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
-        safety_checker ([`StableDiffusionSafetyChecker`]):
-            Classification module that estimates whether generated images could be considered offensive or harmful.
-            Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
-        feature_extractor ([`CLIPFeatureExtractor`]):
-            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
-    """
-
+class StableDiffusionGeneratorPipeline:
     def __init__(
         self,
-        vae: AutoencoderKL,
-        text_encoder: CLIPTextModel,
-        tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
         scheduler: KarrasDiffusionSchedulers,
-        safety_checker: Optional[StableDiffusionSafetyChecker],
-        feature_extractor: Optional[CLIPFeatureExtractor],
-        requires_safety_checker: bool = False,
-        control_model: ControlNetModel = None,
     ):
-        super().__init__(
-            vae,
-            text_encoder,
-            tokenizer,
-            unet,
-            scheduler,
-            safety_checker,
-            feature_extractor,
-            requires_safety_checker,
-        )
-
-        self.register_modules(
-            vae=vae,
-            text_encoder=text_encoder,
-            tokenizer=tokenizer,
-            unet=unet,
-            scheduler=scheduler,
-            safety_checker=safety_checker,
-            feature_extractor=feature_extractor,
-            # FIXME: can't currently register control module
-            # control_model=control_model,
-        )
+        self.unet = unet
+        self.scheduler = scheduler
 
         self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward)
-        self.control_model = control_model
 
-    def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
+    def set_use_memory_efficient_attention_xformers(
+        self, module: torch.nn.Module, valid: bool, attention_op: Optional[Callable] = None
+    ) -> None:
+        # Recursively walk through all the children.
+        # Any child that exposes the set_use_memory_efficient_attention_xformers method
+        # gets the message.
+        def fn_recursive_set_mem_eff(module: torch.nn.Module):
+            if hasattr(module, "set_use_memory_efficient_attention_xformers"):
+                module.set_use_memory_efficient_attention_xformers(valid, attention_op)
+
+            for child in module.children():
+                fn_recursive_set_mem_eff(child)
+
+        fn_recursive_set_mem_eff(module)
+
+
+    def set_attention_slice(self, module: torch.nn.Module, slice_size: Optional[Union[int, str]]):
+        if hasattr(module, "set_attention_slice"):
+            module.set_attention_slice(slice_size)
+
+    def _adjust_memory_efficient_attention(self, model, latents: torch.Tensor):
         """
         if xformers is available, use it, otherwise use sliced attention.
""" config = InvokeAIAppConfig.get_config() if config.attention_type == "xformers": - self.enable_xformers_memory_efficient_attention() + self.set_use_memory_efficient_attention_xformers(model, True) return elif config.attention_type == "sliced": slice_size = config.attention_slice_size @@ -300,29 +243,32 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): slice_size = auto_detect_slice_size(latents) elif slice_size == "balanced": slice_size = "auto" - self.enable_attention_slicing(slice_size=slice_size) + self.set_attention_slice(model, slice_size=slice_size) return elif config.attention_type == "normal": - self.disable_attention_slicing() + self.set_attention_slice(model, slice_size=None) return elif config.attention_type == "torch-sdp": - raise Exception("torch-sdp attention slicing not yet implemented") + if not hasattr(torch.nn.functional, "scaled_dot_product_attention"): + raise Exception("torch-sdp requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + model.set_attn_processor(AttnProcessor2_0()) + return # the remainder if this code is called when attention_type=='auto' - if self.unet.device.type == "cuda": + if model.device.type == "cuda": if is_xformers_available() and not config.disable_xformers: - self.enable_xformers_memory_efficient_attention() + self.set_use_memory_efficient_attention_xformers(model, True) return elif hasattr(torch.nn.functional, "scaled_dot_product_attention"): - # diffusers enable sdp automatically + model.set_attn_processor(AttnProcessor2_0()) return - if self.unet.device.type == "cpu" or self.unet.device.type == "mps": + if model.device.type == "cpu" or model.device.type == "mps": mem_free = psutil.virtual_memory().free - elif self.unet.device.type == "cuda": - mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.unet.device)) + elif model.device.type == "cuda": + mem_free, _ = torch.cuda.mem_get_info(normalize_device(model.device)) else: - raise ValueError(f"unrecognized device {self.unet.device}") + raise ValueError(f"unrecognized device {model.device}") # input tensor of [1, 4, h/8, w/8] # output tensor of [16, (h/8 * w/8), (h/8 * w/8)] bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4 @@ -335,15 +281,12 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): * bytes_per_element_needed_for_baddbmm_duplication ) if max_size_required_for_baddbmm > (mem_free * 3.0 / 4.0): # 3.3 / 4.0 is from old Invoke code - self.enable_attention_slicing(slice_size="max") + self.set_attention_slice(model, slice_size="max") elif torch.backends.mps.is_available(): # diffusers recommends always enabling for mps - self.enable_attention_slicing(slice_size="max") + self.set_attention_slice(model, slice_size="max") else: - self.disable_attention_slicing() - - def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False): - raise Exception("Should not be called") + self.set_attention_slice(model, slice_size=None) def latents_from_embeddings( self, @@ -429,7 +372,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): control_data: List[ControlNetData] = None, callback: Callable[[PipelineIntermediateState], None] = None, ): - self._adjust_memory_efficient_attention(latents) + self._adjust_memory_efficient_attention(self.unet, latents) + if control_data is not None: + for control in control_data: + self._adjust_memory_efficient_attention(control.model, latents) + if additional_guidance is None: additional_guidance = [] @@ -457,7 +404,7 @@ class 
         )
 
         # print("timesteps:", timesteps)
-        for i, t in enumerate(self.progress_bar(timesteps)):
+        for i, t in enumerate(tqdm(timesteps)):
             batched_t = t.expand(batch_size)
             step_output = self.step(
                 batched_t,
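
Note on the `FakeVae` shim removed from `latent.py`: it existed only because the diffusers `StableDiffusionPipeline` constructor inspects the VAE config even when the pipeline is used purely to denoise latents, so a stand-in object with a `block_out_channels` list was needed to keep `__init__` from crashing. A minimal sketch of the superclass lookup the shim was satisfying (paraphrased from diffusers; the exact code varies by version), shown only to explain why dropping the inheritance also lets the shim go:

```python
# Paraphrased sketch of the computation diffusers' StableDiffusionPipeline.__init__
# performs; version-dependent, not the literal upstream code.
def vae_scale_factor(vae) -> int:
    # FakeVae.FakeVaeConfig.block_out_channels == [0] makes this 2**0 == 1,
    # which was enough to satisfy the superclass constructor.
    return 2 ** (len(vae.config.block_out_channels) - 1)
```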
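
For anyone exercising the new `torch-sdp` branch in isolation, here is a self-contained sketch of the behavior it adds. The helper name `enable_torch_sdp` and the standalone framing are illustrative only; both `UNet2DConditionModel` and `ControlNetModel` expose `set_attn_processor`, which is what the diff relies on:

```python
import torch
from diffusers.models import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor2_0

def enable_torch_sdp(model: UNet2DConditionModel) -> None:
    # scaled_dot_product_attention only exists in PyTorch >= 2.0, hence the guard.
    if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
        raise Exception("torch-sdp requires PyTorch 2.0; please upgrade PyTorch.")
    # Swap every attention block's processor for the SDP-backed implementation.
    model.set_attn_processor(AttnProcessor2_0())
```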