from __future__ import annotations

import dataclasses
import inspect
from dataclasses import dataclass, field
from typing import Any, Callable, List, Optional, Union

import einops
import PIL.Image
import psutil
import torch
import torchvision.transforms as T
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.controlnet import ControlNetModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.outputs import BaseOutput
from pydantic import Field
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterXL

from ..util import auto_detect_slice_size, normalize_device
from .diffusion import AttentionMapSaver, BasicConditioningInfo, InvokeAIDiffuserComponent, PostprocessingSettings


@dataclass
class PipelineIntermediateState:
    step: int
    order: int
    total_steps: int
    timestep: int
    latents: torch.Tensor
    predicted_original: Optional[torch.Tensor] = None
    attention_map_saver: Optional[AttentionMapSaver] = None


@dataclass
class AddsMaskLatents:
    """Add the channels required for inpainting model input.

    The inpainting model takes the normal latent channels as input, _plus_ a one-channel mask
    and the latent encoding of the base image.

    This class assumes the same mask and base image should apply to all items in the batch.
    """

    forward: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
    mask: torch.Tensor
    initial_image_latents: torch.Tensor

    def __call__(
        self,
        latents: torch.Tensor,
        t: torch.Tensor,
        text_embeddings: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        model_input = self.add_mask_channels(latents)
        return self.forward(model_input, t, text_embeddings, **kwargs)

    def add_mask_channels(self, latents):
        batch_size = latents.size(0)
        # duplicate mask and latents for each batch
        mask = einops.repeat(self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
        image_latents = einops.repeat(self.initial_image_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
        # add mask and image as additional channels
        model_input, _ = einops.pack([latents, mask, image_latents], "b * h w")
        return model_input
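

# Illustrative sketch (not part of the pipeline API, shapes chosen for a 512x512 image):
# packing a 4-channel latent batch with a broadcast 1-channel mask and the 4-channel
# latents of the source image yields the 9-channel input an inpainting UNet expects.
#
#   adder = AddsMaskLatents(
#       forward=lambda x, t, c: x,  # stand-in for the UNet forward
#       mask=torch.ones(1, 1, 64, 64),
#       initial_image_latents=torch.zeros(1, 4, 64, 64),
#   )
#   adder.add_mask_channels(torch.randn(2, 4, 64, 64)).shape  # torch.Size([2, 9, 64, 64])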


def are_like_tensors(a: torch.Tensor, b: object) -> bool:
    return isinstance(b, torch.Tensor) and (a.size() == b.size())


@dataclass
class AddsMaskGuidance:
    mask: torch.FloatTensor
    mask_latents: torch.FloatTensor
    scheduler: SchedulerMixin
    noise: torch.Tensor

    def __call__(self, step_output: Union[BaseOutput, SchedulerOutput], t: torch.Tensor, conditioning) -> BaseOutput:
        output_class = step_output.__class__  # We'll create a new one with masked data.

        # The problem with taking SchedulerOutput instead of the model output is that we're less certain what's in it.
        # It's reasonable to assume the first thing is prev_sample, but then does it have other things
        # like pred_original_sample? Should we apply the mask to them too?
        # But what if there's just some other random field?
        prev_sample = step_output[0]
        # Mask anything that has the same shape as prev_sample, return others as-is.
        return output_class(
            {
                k: (self.apply_mask(v, self._t_for_field(k, t)) if are_like_tensors(prev_sample, v) else v)
                for k, v in step_output.items()
            }
        )

    def _t_for_field(self, field_name: str, t):
        if field_name == "pred_original_sample":
            return self.scheduler.timesteps[-1]
        return t

    def apply_mask(self, latents: torch.Tensor, t) -> torch.Tensor:
        batch_size = latents.size(0)
        mask = einops.repeat(self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
        if t.dim() == 0:
            # some schedulers expect t to be one-dimensional.
            # TODO: file diffusers bug about inconsistency?
            t = einops.repeat(t, "-> batch", batch=batch_size)
        # Noise shouldn't be re-randomized between steps here. The multistep schedulers
        # get very confused about what is happening from step to step when we do that.
        mask_latents = self.scheduler.add_noise(self.mask_latents, self.noise, t)
        # TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?
        # mask_latents = self.scheduler.scale_model_input(mask_latents, t)
        mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
        masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype))
        return masked_input


def trim_to_multiple_of(*args, multiple_of=8):
    return tuple((x - x % multiple_of) for x in args)
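

# Illustrative example: each dimension is trimmed down to the nearest multiple of 8,
# e.g. trim_to_multiple_of(517, 389) == (512, 384).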


def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool = True, multiple_of=8) -> torch.FloatTensor:
    """
    :param image: input image
    :param normalize: scale the range to [-1, 1] instead of [0, 1]
    :param multiple_of: resize the input so both dimensions are a multiple of this
    """
    w, h = trim_to_multiple_of(*image.size, multiple_of=multiple_of)
    transformation = T.Compose(
        [
            T.Resize((h, w), T.InterpolationMode.LANCZOS, antialias=True),
            T.ToTensor(),
        ]
    )
    tensor = transformation(image)
    if normalize:
        tensor = tensor * 2.0 - 1.0
    return tensor
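

# Usage sketch (assumes an RGB PIL image; shown as a comment so nothing runs at import time):
#
#   img = PIL.Image.new("RGB", (517, 389))
#   image_resized_to_grid_as_tensor(img).shape              # torch.Size([3, 384, 512]), values in [-1, 1]
#   image_resized_to_grid_as_tensor(img, normalize=False)   # values stay in [0, 1]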


def is_inpainting_model(unet: UNet2DConditionModel):
    return unet.conv_in.in_channels == 9
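

# Note: 9 input channels = 4 noise latent channels + 1 mask channel + 4 masked-image latent
# channels, the input layout of the standard Stable Diffusion inpainting UNets; a regular
# text-to-image UNet has a 4-channel conv_in.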


@dataclass
class ControlNetData:
    model: ControlNetModel = Field(default=None)
    image_tensor: torch.Tensor = Field(default=None)
    weight: Union[float, List[float]] = Field(default=1.0)
    begin_step_percent: float = Field(default=0.0)
    end_step_percent: float = Field(default=1.0)
    control_mode: str = Field(default="balanced")
    resize_mode: str = Field(default="just_resize")
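    # weight may be a single float applied at every step or a per-step list of floats;
    # begin_step_percent / end_step_percent bound the fraction of the total denoising
    # steps during which this ControlNet is applied.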


@dataclass
class IPAdapterData:
    ip_adapter_model: str = Field(default=None)
    image_encoder_model: str = Field(default=None)
    image: PIL.Image = Field(default=None)
    # TODO: change to polymorphic so can do different weights per step (once implemented...)
    # weight: Union[float, List[float]] = Field(default=1.0)
    weight: float = Field(default=1.0)


@dataclass
class ConditioningData:
    unconditioned_embeddings: BasicConditioningInfo
    text_embeddings: BasicConditioningInfo
    guidance_scale: Union[float, List[float]]
    """
    Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
    `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
    Guidance scale is enabled by setting `guidance_scale > 1`. A higher guidance scale encourages generating
    images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
    """
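    # Note (informal): with standard classifier-free guidance the two UNet predictions are
    # combined as noise_pred = uncond + guidance_scale * (cond - uncond), so a value of 1
    # reduces to the conditioned prediction alone; ~7.5 is the usual default for Stable Diffusion.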
    extra: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo] = None
    scheduler_args: dict[str, Any] = field(default_factory=dict)
    """
    Additional arguments to pass to scheduler.step() (see add_scheduler_args_if_applicable()).
    """
    postprocessing_settings: Optional[PostprocessingSettings] = None

    @property
    def dtype(self):
        return self.text_embeddings.dtype

    def add_scheduler_args_if_applicable(self, scheduler, **kwargs):
        scheduler_args = dict(self.scheduler_args)
        step_method = inspect.signature(scheduler.step)
        for name, value in kwargs.items():
            try:
                step_method.bind_partial(**{name: value})
            except TypeError:
                # FIXME: don't silently discard arguments
                pass  # debug("%s does not accept argument named %r", scheduler, name)
            else:
                scheduler_args[name] = value
        return dataclasses.replace(self, scheduler_args=scheduler_args)
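
    # Usage sketch (hypothetical caller, not part of this module): only kwargs that
    # scheduler.step() actually accepts are kept, so e.g. passing eta to a scheduler whose
    # step() has no eta parameter is silently dropped rather than raising:
    #
    #   conditioning_data = conditioning_data.add_scheduler_args_if_applicable(
    #       scheduler, eta=0.0, generator=torch.Generator().manual_seed(seed)
    #   )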


@dataclass
class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput):
    r"""
    Output class for InvokeAI's Stable Diffusion pipeline.

    Args:
        attention_map_saver (`AttentionMapSaver`): Object containing attention maps that can be displayed to the user
            after generation completes. Optional.
    """

    attention_map_saver: Optional[AttentionMapSaver]


class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Implementation note: This class started as a refactored copy of diffusers.StableDiffusionPipeline.
    Hopefully future versions of diffusers provide access to more of these functions so that we don't
    need to duplicate them here: https://github.com/huggingface/diffusers/issues/551#issuecomment-1281508384

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
        feature_extractor ([`CLIPFeatureExtractor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: Optional[StableDiffusionSafetyChecker],
        feature_extractor: Optional[CLIPFeatureExtractor],
        requires_safety_checker: bool = False,
        control_model: ControlNetModel = None,
    ):
        super().__init__(
            vae,
            text_encoder,
            tokenizer,
            unet,
            scheduler,
            safety_checker,
            feature_extractor,
            requires_safety_checker,
        )

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            # FIXME: can't currently register control module
            # control_model=control_model,
        )
        self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward)
        self.control_model = control_model

    def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
        """
        If xformers is available, use it, otherwise use sliced attention.
        """
        config = InvokeAIAppConfig.get_config()
        if config.attention_type == "xformers":
            self.enable_xformers_memory_efficient_attention()
            return
        elif config.attention_type == "sliced":
            slice_size = config.attention_slice_size
            if slice_size == "auto":
                slice_size = auto_detect_slice_size(latents)
            elif slice_size == "balanced":
                slice_size = "auto"
            self.enable_attention_slicing(slice_size=slice_size)
            return
        elif config.attention_type == "normal":
            self.disable_attention_slicing()
            return
        elif config.attention_type == "torch-sdp":
            raise Exception("torch-sdp attention slicing not yet implemented")

        # The remainder of this code runs when attention_type == 'auto'.
        if self.unet.device.type == "cuda":
            if is_xformers_available() and not config.disable_xformers:
                self.enable_xformers_memory_efficient_attention()
                return
            elif hasattr(torch.nn.functional, "scaled_dot_product_attention"):
                # diffusers enables torch SDP automatically
                return

        if self.unet.device.type == "cpu" or self.unet.device.type == "mps":
            mem_free = psutil.virtual_memory().free
        elif self.unet.device.type == "cuda":
            mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.unet.device))
        else:
            raise ValueError(f"unrecognized device {self.unet.device}")
        # input tensor of [1, 4, h/8, w/8]
        # output tensor of [16, (h/8 * w/8), (h/8 * w/8)]
        bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4
        max_size_required_for_baddbmm = (
            16
            * latents.size(dim=2)
            * latents.size(dim=3)
            * latents.size(dim=2)
            * latents.size(dim=3)
            * bytes_per_element_needed_for_baddbmm_duplication
        )
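        # Rough worked example (illustrative, fp16 latents): a 512x512 image gives 64x64
        # latents, so 16 * 64*64 * 64*64 * (2 + 4) bytes ~= 1.6 GB would be needed for the
        # unsliced attention baddbmm buffer under this heuristic.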
        if max_size_required_for_baddbmm > (mem_free * 3.0 / 4.0):  # 3.3 / 4.0 is from old Invoke code
            self.enable_attention_slicing(slice_size="max")
        elif torch.backends.mps.is_available():
            # diffusers recommends always enabling for mps
            self.enable_attention_slicing(slice_size="max")
        else:
            self.disable_attention_slicing()

    def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False):
        raise Exception("Should not be called")

    def latents_from_embeddings(
        self,
        latents: torch.Tensor,
        num_inference_steps: int,
        conditioning_data: ConditioningData,
        *,
        noise: Optional[torch.Tensor],
        timesteps: torch.Tensor,
        init_timestep: torch.Tensor,
        additional_guidance: List[Callable] = None,
        callback: Callable[[PipelineIntermediateState], None] = None,
        control_data: List[ControlNetData] = None,
        ip_adapter_data: IPAdapterData = None,
        mask: Optional[torch.Tensor] = None,
        masked_latents: Optional[torch.Tensor] = None,
        seed: Optional[int] = None,
    ) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
        if init_timestep.shape[0] == 0:
            return latents, None

        if additional_guidance is None:
            additional_guidance = []

        orig_latents = latents.clone()

        batch_size = latents.shape[0]
        batched_t = init_timestep.expand(batch_size)

        if noise is not None:
            # latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
            latents = self.scheduler.add_noise(latents, noise, batched_t)

        if mask is not None:
            # if no noise provided, noisify unmasked area based on seed (or 0 as fallback)
            if noise is None:
                noise = torch.randn(
                    orig_latents.shape,
                    dtype=torch.float32,
                    device="cpu",
                    generator=torch.Generator(device="cpu").manual_seed(seed or 0),
                ).to(device=orig_latents.device, dtype=orig_latents.dtype)

                latents = self.scheduler.add_noise(latents, noise, batched_t)
                latents = torch.lerp(
                    orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype)
                )

            if is_inpainting_model(self.unet):
                if masked_latents is None:
                    raise Exception("Source image required for inpaint mask when inpaint model used!")

                self.invokeai_diffuser.model_forward_callback = AddsMaskLatents(
                    self._unet_forward, mask, masked_latents
                )
            else:
                additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise))

        try:
            latents, attention_map_saver = self.generate_latents_from_embeddings(
                latents,
                timesteps,
                conditioning_data,
                additional_guidance=additional_guidance,
                control_data=control_data,
                ip_adapter_data=ip_adapter_data,
                callback=callback,
            )
        finally:
            self.invokeai_diffuser.model_forward_callback = self._unet_forward

        # restore unmasked part
        if mask is not None:
            latents = torch.lerp(orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype))

        return latents, attention_map_saver
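
    # Hypothetical caller sketch (the real invocation lives outside this module, in the
    # nodes/backend layer): the caller prepares `timesteps` and `init_timestep` from the
    # scheduler (e.g. via scheduler.set_timesteps) and then runs roughly
    #
    #   result_latents, _ = pipeline.latents_from_embeddings(
    #       latents=initial_latents,
    #       num_inference_steps=num_steps,
    #       conditioning_data=conditioning_data,
    #       noise=noise,
    #       timesteps=timesteps,
    #       init_timestep=init_timestep,
    #   )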

    def generate_latents_from_embeddings(
        self,
        latents: torch.Tensor,
        timesteps,
        conditioning_data: ConditioningData,
        *,
        additional_guidance: List[Callable] = None,
        control_data: List[ControlNetData] = None,
        ip_adapter_data: List[IPAdapterData] = None,
        callback: Callable[[PipelineIntermediateState], None] = None,
    ):
        self._adjust_memory_efficient_attention(latents)
        if additional_guidance is None:
            additional_guidance = []

        batch_size = latents.shape[0]
        attention_map_saver: Optional[AttentionMapSaver] = None

        if timesteps.shape[0] == 0:
            return latents, attention_map_saver

        # print("ip_adapter_image: ", type(ip_adapter_image))
        if ip_adapter_data is not None and len(ip_adapter_data) > 0:
            ip_adapter_info = ip_adapter_data[0]
            ip_adapter_image = ip_adapter_info.image
            # initialize IPAdapter
            print(" width:", ip_adapter_image.width, " height:", ip_adapter_image.height)
            # FIXME:
            # WARNING!
            # IPAdapter constructor modifies UNet model in-place
            # Adds additional cross-attention layers to UNet model for image embedding
            # need to figure out how to only do this if UNet hasn't already been modified by prior IPAdapter
            # and how to undo if ip_adapter_image is removed
            # Should reimplement to use existing model management context etc.
            #
            if "sdxl" in ip_adapter_info.ip_adapter_model:
                print("using IPAdapterXL")
                ip_adapter = IPAdapterXL(
                    self, ip_adapter_info.image_encoder_model, ip_adapter_info.ip_adapter_model, self.unet.device
                )
            elif "plus" in ip_adapter_info.ip_adapter_model:
                print("using IPAdapterPlus")
                ip_adapter = IPAdapterPlus(
                    self,  # IPAdapterPlus first arg is StableDiffusionPipeline
                    ip_adapter_info.image_encoder_model,
                    ip_adapter_info.ip_adapter_model,
                    self.unet.device,
                    num_tokens=16,
                )
            else:
                print("using IPAdapter")
                ip_adapter = IPAdapter(
                    self,  # IPAdapter first arg is StableDiffusionPipeline
                    ip_adapter_info.image_encoder_model,
                    ip_adapter_info.ip_adapter_model,
                    self.unet.device,
                )
            # IP-Adapter ==> add additional cross-attention layers to UNet model here?
            ip_adapter.set_scale(ip_adapter_info.weight)
            print("ip_adapter:", ip_adapter)

            # get image embedding from CLIP and ImageProjModel
            print("getting image embeddings from IP-Adapter...")
            num_samples = 1  # hardwiring for first pass
            image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter.get_image_embeds(ip_adapter_image)
            print("image cond embeds shape:", image_prompt_embeds.shape)
            print("image uncond embeds shape:", uncond_image_prompt_embeds.shape)
            bs_embed, seq_len, _ = image_prompt_embeds.shape
            image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
            image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
            uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
            uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
            print("image cond embeds shape:", image_prompt_embeds.shape)
            print("image uncond embeds shape:", uncond_image_prompt_embeds.shape)

            # IP-Adapter: run IP-Adapter model here?
            # and add output as additional cross-attention layers
            text_prompt_embeds = conditioning_data.text_embeddings.embeds
            uncond_text_prompt_embeds = conditioning_data.unconditioned_embeddings.embeds
            print("text embeds shape:", text_prompt_embeds.shape)
            concat_prompt_embeds = torch.cat([text_prompt_embeds, image_prompt_embeds], dim=1)
            concat_uncond_prompt_embeds = torch.cat([uncond_text_prompt_embeds, uncond_image_prompt_embeds], dim=1)
            print("concat embeds shape:", concat_prompt_embeds.shape)
            conditioning_data.text_embeddings.embeds = concat_prompt_embeds
            conditioning_data.unconditioned_embeddings.embeds = concat_uncond_prompt_embeds
        else:
            image_prompt_embeds = None
            uncond_image_prompt_embeds = None

        extra_conditioning_info = conditioning_data.extra
        with self.invokeai_diffuser.custom_attention_context(
            self.invokeai_diffuser.model,
            extra_conditioning_info=extra_conditioning_info,
            step_count=len(self.scheduler.timesteps),
        ):
            if callback is not None:
                callback(
                    PipelineIntermediateState(
                        step=-1,
                        order=self.scheduler.order,
                        total_steps=len(timesteps),
                        timestep=self.scheduler.config.num_train_timesteps,
                        latents=latents,
                    )
                )

            # print("timesteps:", timesteps)
            for i, t in enumerate(self.progress_bar(timesteps)):
                batched_t = t.expand(batch_size)
                step_output = self.step(
                    batched_t,
                    latents,
                    conditioning_data,
                    step_index=i,
                    total_step_count=len(timesteps),
                    additional_guidance=additional_guidance,
                    control_data=control_data,
                )
                latents = step_output.prev_sample

                latents = self.invokeai_diffuser.do_latent_postprocessing(
                    postprocessing_settings=conditioning_data.postprocessing_settings,
                    latents=latents,
                    sigma=batched_t,
                    step_index=i,
                    total_step_count=len(timesteps),
                )

                predicted_original = getattr(step_output, "pred_original_sample", None)

                # TODO resuscitate attention map saving
                # if i == len(timesteps)-1 and extra_conditioning_info is not None:
                #     eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
                #     attention_map_token_ids = range(1, eos_token_index)
                #     attention_map_saver = AttentionMapSaver(token_ids=attention_map_token_ids, latents_shape=latents.shape[-2:])
                #     self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)

                if callback is not None:
                    callback(
                        PipelineIntermediateState(
                            step=i,
                            order=self.scheduler.order,
                            total_steps=len(timesteps),
                            timestep=int(t),
                            latents=latents,
                            predicted_original=predicted_original,
                            attention_map_saver=attention_map_saver,
                        )
                    )

            return latents, attention_map_saver

    @torch.inference_mode()
    def step(
        self,
        t: torch.Tensor,
        latents: torch.Tensor,
        conditioning_data: ConditioningData,
        step_index: int,
        total_step_count: int,
        additional_guidance: List[Callable] = None,
        control_data: List[ControlNetData] = None,
    ):
        # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
        timestep = t[0]
        if additional_guidance is None:
            additional_guidance = []

        # TODO: should this scaling happen here or inside self._unet_forward?
        #     i.e. before or after passing it to InvokeAIDiffuserComponent
        latent_model_input = self.scheduler.scale_model_input(latents, timestep)

        # default is no controlnet, so set controlnet processing output to None
        controlnet_down_block_samples, controlnet_mid_block_sample = None, None
        if control_data is not None:
            controlnet_down_block_samples, controlnet_mid_block_sample = self.invokeai_diffuser.do_controlnet_step(
                control_data=control_data,
                sample=latent_model_input,
                timestep=timestep,
                step_index=step_index,
                total_step_count=total_step_count,
                conditioning_data=conditioning_data,
            )

        uc_noise_pred, c_noise_pred = self.invokeai_diffuser.do_unet_step(
            sample=latent_model_input,
            timestep=t,  # TODO: debug how batched and non-batched timesteps are handled
            step_index=step_index,
            total_step_count=total_step_count,
            conditioning_data=conditioning_data,
            # extra:
            down_block_additional_residuals=controlnet_down_block_samples,  # from controlnet(s)
            mid_block_additional_residual=controlnet_mid_block_sample,  # from controlnet(s)
        )

        guidance_scale = conditioning_data.guidance_scale
        if isinstance(guidance_scale, list):
            guidance_scale = guidance_scale[step_index]

        noise_pred = self.invokeai_diffuser._combine(
            uc_noise_pred,
            c_noise_pred,
            guidance_scale,
        )

        # compute the previous noisy sample x_t -> x_t-1
        step_output = self.scheduler.step(noise_pred, timestep, latents, **conditioning_data.scheduler_args)

        # TODO: issue to diffusers?
        # Undo the internal counter increment done by scheduler.step, so the timestep can be resolved
        # as it was before the call. This is needed to be able to call scheduler.add_noise with the
        # current timestep.
        if self.scheduler.order == 2:
            self.scheduler._index_counter[timestep.item()] -= 1

        # TODO: this additional_guidance extension point feels redundant with InvokeAIDiffuserComponent.
        # But the way things are now, the scheduler runs _after_ that, so there was
        # no way to use it to apply an operation that happens after the last scheduler.step.
        for guidance in additional_guidance:
            step_output = guidance(step_output, timestep, conditioning_data)

        # restore internal counter
        if self.scheduler.order == 2:
            self.scheduler._index_counter[timestep.item()] += 1

        return step_output

    def _unet_forward(
        self,
        latents,
        t,
        text_embeddings,
        cross_attention_kwargs: Optional[dict[str, Any]] = None,
        **kwargs,
    ):
        """predict the noise residual"""
        if is_inpainting_model(self.unet) and latents.size(1) == 4:
            # Pad out normal non-inpainting inputs for an inpainting model.
            # FIXME: There are too many layers of functions and we have too many different ways of
            #     overriding things! This should get handled in a way more consistent with the other
            #     use of AddsMaskLatents.
            latents = AddsMaskLatents(
                self._unet_forward,
                mask=torch.ones_like(latents[:1, :1], device=latents.device, dtype=latents.dtype),
                initial_image_latents=torch.zeros_like(latents[:1], device=latents.device, dtype=latents.dtype),
            ).add_mask_channels(latents)

        # First three args should be positional, not keywords, so torch hooks can see them.
        return self.unet(
            latents,
            t,
            text_embeddings,
            cross_attention_kwargs=cross_attention_kwargs,
            **kwargs,
        ).sample