InvokeAI/invokeai/backend/stable_diffusion/diffusers_pipeline.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

1010 lines
42 KiB
Python
Raw Normal View History

2023-02-28 05:37:13 +00:00
from __future__ import annotations
import dataclasses
import inspect
import math
2023-02-28 05:37:13 +00:00
import secrets
from collections.abc import Sequence
from dataclasses import dataclass, field
2023-03-03 06:02:00 +00:00
from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
from pydantic import BaseModel, Field
2023-02-28 05:37:13 +00:00
import einops
2023-03-03 06:02:00 +00:00
import PIL.Image
import numpy as np
2023-03-13 13:11:09 +00:00
from accelerate.utils import set_seed
2023-02-28 05:37:13 +00:00
import psutil
import torch
import torchvision.transforms as T
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.controlnet import ControlNetModel, ControlNetOutput
2023-02-28 05:37:13 +00:00
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
2023-03-03 06:02:00 +00:00
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
StableDiffusionPipeline,
)
from diffusers.pipelines.controlnet import MultiControlNetModel
2023-03-03 06:02:00 +00:00
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import (
StableDiffusionImg2ImgPipeline,
)
from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
)
2023-02-28 05:37:13 +00:00
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
from diffusers.utils import PIL_INTERPOLATION
2023-02-28 05:37:13 +00:00
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.outputs import BaseOutput
from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from typing_extensions import ParamSpec
from invokeai.app.services.config import InvokeAIAppConfig
2023-03-03 06:02:00 +00:00
from ..util import CPU_DEVICE, normalize_device
from .diffusion import (
AttentionMapSaver,
InvokeAIDiffuserComponent,
PostprocessingSettings,
)
from .offloading import FullyLoadedModelGroup, LazilyLoadedModelGroup, ModelGroup
2023-03-03 06:02:00 +00:00
2023-02-28 05:37:13 +00:00
@dataclass
class PipelineIntermediateState:
    """Snapshot of a denoising run, as passed to progress callbacks."""

    # Identifier of the generation run this snapshot belongs to.
    run_id: str
    # Index of the step that produced this snapshot.
    step: int
    # Scheduler timestep associated with that step.
    timestep: int
    # Latents as they stand at this point of the run.
    latents: torch.Tensor
    # The scheduler's predicted original sample for the step, when it provides one.
    predicted_original: Optional[torch.Tensor] = None
    # Attention maps collected during the run, if any.
    attention_map_saver: Optional[AttentionMapSaver] = None
@dataclass
class AddsMaskLatents:
    """Wrap a forward function so that it receives inpainting input channels.

    The inpainting model takes the normal latent channels as input, _plus_ a
    one-channel mask and the latent encoding of the base image.

    The same mask and base image are applied to every item in the batch.
    """

    forward: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
    mask: torch.Tensor
    initial_image_latents: torch.Tensor

    def __call__(
        self,
        latents: torch.Tensor,
        t: torch.Tensor,
        text_embeddings: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """Append the mask channels to ``latents`` and delegate to ``forward``.

        :param latents: latents to extend with the mask and base-image channels
        :param t: timestep, passed through to ``forward`` unchanged
        :param text_embeddings: passed through to ``forward`` unchanged
        :param kwargs: extra keyword arguments forwarded to ``forward``
        :return: whatever ``forward`` returns
        """
        return self.forward(
            self.add_mask_channels(latents), t, text_embeddings, **kwargs
        )

    def add_mask_channels(self, latents):
        """Return ``latents`` packed together with the mask and base-image latents.

        :param latents: tensor whose first dimension is the batch size
        :return: a single tensor packing latents, mask and image latents along
            the second axis
        """
        repeats = latents.size(0)
        # Tile the stored mask and image latents once per item in the batch.
        per_batch = "b c h w -> (repeat b) c h w"
        tiled_mask = einops.repeat(self.mask, per_batch, repeat=repeats)
        tiled_image_latents = einops.repeat(
            self.initial_image_latents, per_batch, repeat=repeats
        )
        # The mask and image become additional channels of the model input.
        packed, _ = einops.pack(
            [latents, tiled_mask, tiled_image_latents], "b * h w"
        )
        return packed
def are_like_tensors(a: torch.Tensor, b: object) -> bool:
    """Return True when ``b`` is a tensor whose size equals the size of ``a``."""
    if not isinstance(b, torch.Tensor):
        return False
    return a.size() == b.size()
2023-02-28 05:37:13 +00:00
@dataclass
class AddsMaskGuidance:
    """Blend noised mask latents back into a scheduler's step output."""

    mask: torch.FloatTensor
    mask_latents: torch.FloatTensor
    scheduler: SchedulerMixin
    noise: torch.Tensor
    _debug: Optional[Callable] = None

    def __call__(
        self, step_output: BaseOutput | SchedulerOutput, t: torch.Tensor, conditioning
    ) -> BaseOutput:
        """Return a new output of the same class with the mask applied.

        :param step_output: output of a scheduler step
        :param t: timestep of the step
        :param conditioning: not used by this method
        :return: a new instance of ``step_output``'s class
        """
        # The problem with taking SchedulerOutput instead of the model output is that
        # we're less certain what's in it. It's reasonable to assume the first thing is
        # prev_sample, but then does it have other things like pred_original_sample?
        # Should we apply the mask to them too? And what if there's just some other
        # random field?
        reference = step_output[0]

        # Mask anything that has the same shape as the reference; keep the rest as-is.
        masked_fields = {}
        for key, value in step_output.items():
            if are_like_tensors(reference, value):
                value = self.apply_mask(value, self._t_for_field(key, t))
            masked_fields[key] = value
        return type(step_output)(masked_fields)

    def _t_for_field(self, field_name: str, t):
        """Return the timestep at which ``field_name`` should be masked."""
        if field_name != "pred_original_sample":
            return t
        # pred_original_sample represents t=0
        return torch.zeros_like(t, dtype=t.dtype)

    def apply_mask(self, latents: torch.Tensor, t) -> torch.Tensor:
        """Interpolate between the noised mask latents and ``latents`` using the mask.

        :param latents: tensor whose first dimension is the batch size
        :param t: timestep tensor; a zero-dimensional value is expanded to one entry
            per batch item
        :return: the interpolated tensor
        """
        batch_size = latents.size(0)
        per_batch = "b c h w -> (repeat b) c h w"
        mask = einops.repeat(self.mask, per_batch, repeat=batch_size)
        if t.dim() == 0:
            # some schedulers expect t to be one-dimensional.
            # TODO: file diffusers bug about inconsistency?
            t = einops.repeat(t, "-> batch", batch=batch_size)
        # Noise shouldn't be re-randomized between steps here. The multistep schedulers
        # get very confused about what is happening from step to step when we do that.
        noised = self.scheduler.add_noise(self.mask_latents, self.noise, t)
        # TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise
        # appropriately scaled already?
        noised = einops.repeat(noised, per_batch, repeat=batch_size)
        target_dtype = latents.dtype
        masked_input = torch.lerp(
            noised.to(dtype=target_dtype), latents, mask.to(dtype=target_dtype)
        )
        if self._debug:
            self._debug(masked_input, f"t={t} lerped")
        return masked_input
def trim_to_multiple_of(*args, multiple_of=8):
    """Reduce each positional value by its remainder modulo ``multiple_of``.

    :param args: values to trim
    :param multiple_of: divisor whose remainder is subtracted from each value
    :return: tuple of trimmed values, in the order given
    """
    return tuple([value - value % multiple_of for value in args])
2023-03-03 06:02:00 +00:00
def image_resized_to_grid_as_tensor(
    image: PIL.Image.Image, normalize: bool = True, multiple_of=8
) -> torch.FloatTensor:
    """
    Resize an image onto a grid and convert it to a tensor.

    :param image: input image
    :param normalize: scale the range to [-1, 1] instead of [0, 1]
    :param multiple_of: resize the input so both dimensions are a multiple of this
    :return: the resized image as a tensor
    """
    grid_w, grid_h = trim_to_multiple_of(*image.size, multiple_of=multiple_of)
    # Resize takes (height, width), whereas PIL reports size as (width, height).
    resize = T.Resize((grid_h, grid_w), T.InterpolationMode.LANCZOS)
    tensor = T.ToTensor()(resize(image))
    return tensor * 2.0 - 1.0 if normalize else tensor
def is_inpainting_model(unet: UNet2DConditionModel):
    """Report whether ``unet``'s input convolution takes nine channels.

    NOTE(review): nine presumably corresponds to latent channels plus a one-channel
    mask plus base-image latent channels -- confirm against the model definition.
    """
    input_channels = unet.conv_in.in_channels
    return input_channels == 9
2023-03-03 06:02:00 +00:00
# Type parameters for GeneratorToCallbackinator: the type of value forwarded to the
# callback, the wrapped generator method's declared return type, and its parameters.
CallbackType = TypeVar("CallbackType")
ReturnType = TypeVar("ReturnType")
ParamType = ParamSpec("ParamType")
2023-02-28 05:37:13 +00:00
@dataclass(frozen=True)
class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
    """Adapt a generator method into a plain call that reports through a callback.

    Calling the instance drains the generator, passes each yielded item of type
    ``callback_arg_type`` to the callback, and returns the last item yielded.
    """

    generator_method: Callable[ParamType, ReturnType]
    callback_arg_type: Type[CallbackType]

    def __call__(
        self,
        *args: ParamType.args,
        callback: Callable[[CallbackType], Any] = None,
        **kwargs: ParamType.kwargs,
    ) -> ReturnType:
        """Run the generator to completion and return its final yielded item.

        :param args: positional arguments for ``generator_method``
        :param callback: optional callable invoked with each yielded item that is an
            instance of ``callback_arg_type``
        :param kwargs: keyword arguments for ``generator_method``
        :raises AssertionError: if the last yielded item is ``None``, which includes
            the generator yielding nothing
        """
        last_item = None
        notify = callback is not None
        for last_item in self.generator_method(*args, **kwargs):
            if notify and isinstance(last_item, self.callback_arg_type):
                callback(last_item)
        if last_item is None:
            raise AssertionError("why was that an empty generator?")
        return last_item
@dataclass
class ControlNetData:
    """Settings for applying one ControlNet during denoising.

    Defaults are plain values. A standard-library dataclass does not interpret
    pydantic's ``Field(...)``, so using it as a default would store the pydantic
    field descriptor itself as the attribute value.
    """

    # The ControlNet model to run.
    model: Optional[ControlNetModel] = None
    # Control image, already converted to a tensor.
    image_tensor: Optional[torch.Tensor] = None
    # Either a single weight or a list of weights.
    weight: Union[float, List[float]] = 1.0
    # Fractions of the total step count between which the ControlNet is applied.
    begin_step_percent: float = 0.0
    end_step_percent: float = 1.0
    # Values checked by the pipeline step include "more_prompt", "more_control"
    # and "unbalanced".
    control_mode: str = "balanced"
2023-02-28 05:37:13 +00:00
@dataclass(frozen=True)
class ConditioningData:
    """Immutable bundle of conditioning inputs for a denoising run."""

    unconditioned_embeddings: torch.Tensor
    text_embeddings: torch.Tensor
    # Guidance scale as defined in
    # [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
    # `guidance_scale` is defined as `w` of equation 2. of the
    # [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled
    # by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
    # images that are closely linked to the text `prompt`, usually at the expense of
    # lower image quality. May be a single value or a list of values.
    guidance_scale: Union[float, List[float]]
    extra: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo] = None
    scheduler_args: dict[str, Any] = field(default_factory=dict)
    # Additional arguments to pass to invokeai_diffuser.do_latent_postprocessing().
    postprocessing_settings: Optional[PostprocessingSettings] = None

    @property
    def dtype(self):
        """dtype of the text embeddings."""
        return self.text_embeddings.dtype

    def add_scheduler_args_if_applicable(self, scheduler, **kwargs):
        """Return a copy whose ``scheduler_args`` gain the accepted ``kwargs``.

        A keyword argument is accepted when it can be bound to the signature of
        ``scheduler.step``; arguments that cannot be bound are dropped.

        :param scheduler: object whose ``step`` signature decides what is accepted
        :param kwargs: candidate arguments for ``scheduler.step``
        :return: a new ``ConditioningData``; this instance is left unchanged
        """
        step_signature = inspect.signature(scheduler.step)

        def is_accepted(name, value):
            try:
                step_signature.bind_partial(**{name: value})
            except TypeError:
                # FIXME: don't silently discard arguments
                return False
            return True

        accepted = {
            name: value for name, value in kwargs.items() if is_accepted(name, value)
        }
        merged_args = {**self.scheduler_args, **accepted}
        return dataclasses.replace(self, scheduler_args=merged_args)
2023-03-03 06:02:00 +00:00
2023-02-28 05:37:13 +00:00
@dataclass
class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput):
    r"""
    Result of a run of InvokeAI's Stable Diffusion pipeline.

    Args:
        attention_map_saver (`AttentionMapSaver`): Optional object holding attention maps
            that can be shown to the user once generation has finished.
    """

    attention_map_saver: Optional[AttentionMapSaver]
class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
r"""
Pipeline for text-to-image generation using Stable Diffusion.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Implementation note: This class started as a refactored copy of diffusers.StableDiffusionPipeline.
Hopefully future versions of diffusers provide access to more of these functions so that we don't
need to duplicate them here: https://github.com/huggingface/diffusers/issues/551#issuecomment-1281508384
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]):
Frozen text-encoder. Stable Diffusion uses the text portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
2023-02-28 05:37:13 +00:00
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
2023-02-28 05:37:13 +00:00
Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_model_group: ModelGroup
ID_LENGTH = 8
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet2DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
    safety_checker: Optional[StableDiffusionSafetyChecker],
    feature_extractor: Optional[CLIPFeatureExtractor],
    requires_safety_checker: bool = False,
    precision: str = "float32",
    control_model: ControlNetModel = None,
    execution_device: Optional[torch.device] = None,
):
    """Build the pipeline, register its modules and set up its model group.

    :param requires_safety_checker: forwarded to the superclass initializer
    :param precision: accepted but not used by this initializer
    :param control_model: stored on the instance; not registered as a module
    :param execution_device: device for the model group; when falsy the unet's
        device is used instead
    """
    super().__init__(
        vae,
        text_encoder,
        tokenizer,
        unet,
        scheduler,
        safety_checker,
        feature_extractor,
        requires_safety_checker,
    )

    # FIXME: can't currently register control module (control_model=control_model)
    modules = dict(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
    )
    self.register_modules(**modules)

    self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward)

    target_device = execution_device or self.unet.device
    self._model_group = FullyLoadedModelGroup(target_device)
    self._model_group.install(*self._submodels)
    self.control_model = control_model
2023-02-28 05:37:13 +00:00
def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
    """
    Choose an attention implementation for the given latents.

    xformers is enabled when CUDA and xformers are available and the app config does
    not disable it. Otherwise attention slicing is switched on or off depending on
    free memory and on whether the MPS backend is available.

    :param latents: latents about to be processed; their element size and the sizes
        of dimensions 2 and 3 feed the memory estimate
    :raises ValueError: if the pipeline device is not cpu, mps or cuda
    """
    config = InvokeAIAppConfig.get_config()
    use_xformers = (
        torch.cuda.is_available()
        and is_xformers_available()
        and not config.disable_xformers
    )
    if use_xformers:
        self.enable_xformers_memory_efficient_attention()
        return

    device_type = self.device.type
    if device_type in ("cpu", "mps"):
        mem_free = psutil.virtual_memory().free
    elif device_type == "cuda":
        mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.device))
    else:
        raise ValueError(f"unrecognized device {self.device}")

    # input tensor of [1, 4, h/8, w/8]
    # output tensor of [16, (h/8 * w/8), (h/8 * w/8)]
    bytes_per_element = latents.element_size() + 4
    spatial_size = latents.size(dim=2) * latents.size(dim=3)
    required_for_baddbmm = 16 * spatial_size * spatial_size * bytes_per_element

    # NOTE(review): an older comment here cited 3.3 / 4.0 from old Invoke code, while
    # the threshold actually applied is 3.0 / 4.0 of free memory -- confirm intent.
    memory_budget = mem_free * 3.0 / 4.0
    # diffusers recommends always enabling slicing for mps
    if required_for_baddbmm > memory_budget or torch.backends.mps.is_available():
        self.enable_attention_slicing(slice_size="max")
    else:
        self.disable_attention_slicing()
2023-02-28 05:37:13 +00:00
def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False):
    """Move the pipeline's model group to ``torch_device``.

    :param torch_device: target device; ``None`` leaves the pipeline untouched
    :param silence_dtype_warnings: accepted to match the superclass signature; not
        used by this override
    :return: this pipeline, so that calls can be chained
    """
    # overridden method; types match the superclass.
    if torch_device is None:
        return self
    self._model_group.set_device(torch.device(torch_device))
    self._model_group.ready()
    # Return self on this path too: falling through gave callers None.
    return self
@property
def device(self) -> torch.device:
    """Execution device reported by this pipeline's model group."""
    model_group = self._model_group
    return model_group.execution_device
@property
def _submodels(self) -> Sequence[torch.nn.Module]:
module_names, _, _ = self.extract_init_dict(dict(self.config))
submodels = []
for name in module_names.keys():
if hasattr(self, name):
value = getattr(self, name)
else:
value = getattr(self.config, name)
if isinstance(value, torch.nn.Module):
submodels.append(value)
return submodels
2023-02-28 05:37:13 +00:00
2023-03-03 06:02:00 +00:00
def image_from_embeddings(
    self,
    latents: torch.Tensor,
    num_inference_steps: int,
    conditioning_data: ConditioningData,
    *,
    noise: torch.Tensor,
    callback: Callable[[PipelineIntermediateState], None] = None,
    run_id=None,
    **kwargs,
) -> InvokeAIStableDiffusionPipelineOutput:
    r"""
    Denoise the latents, decode them into an image and run the safety check.

    :param latents: Pre-generated un-noised latents, to be used as inputs for
        image generation. Can be used to tweak the same generation with different prompts.
    :param num_inference_steps: The number of denoising steps. More denoising steps usually
        lead to a higher quality image at the expense of slower inference.
    :param conditioning_data: conditioning for the run; its dtype is passed to the safety check
    :param noise: Noise to add to the latents, sampled from a Gaussian distribution.
    :param callback: optional callable that receives intermediate states
    :param run_id: optional identifier for the run
    :param kwargs: forwarded to ``latents_from_embeddings``
    :return: the result of ``check_for_safety`` applied to the pipeline output
    """
    final_latents, saver = self.latents_from_embeddings(
        latents,
        num_inference_steps,
        conditioning_data,
        noise=noise,
        run_id=run_id,
        callback=callback,
        **kwargs,
    )

    # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
    torch.cuda.empty_cache()

    # NOTE(review): the nesting of the statements below under inference_mode is
    # reconstructed; confirm against the repository copy of this file.
    with torch.inference_mode():
        decoded = self.decode_latents(final_latents)
        output = InvokeAIStableDiffusionPipelineOutput(
            images=decoded,
            nsfw_content_detected=[],
            attention_map_saver=saver,
        )
        return self.check_for_safety(output, dtype=conditioning_data.dtype)
2023-03-03 06:02:00 +00:00
def latents_from_embeddings(
    self,
    latents: torch.Tensor,
    num_inference_steps: int,
    conditioning_data: ConditioningData,
    *,
    noise: torch.Tensor,
    timesteps=None,
    additional_guidance: List[Callable] = None,
    run_id=None,
    callback: Callable[[PipelineIntermediateState], None] = None,
    control_data: List[ControlNetData] = None,
    **kwargs,
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
    """Run the denoising loop to completion and return the final latents.

    :param latents: starting latents
    :param num_inference_steps: number of steps used to set the scheduler's timesteps;
        only consulted when ``timesteps`` is None
    :param conditioning_data: conditioning for the run
    :param noise: noise passed on to the denoising generator
    :param timesteps: explicit timesteps; when None they are taken from the scheduler
    :param additional_guidance: optional guidance callables, passed on unchanged
    :param run_id: optional identifier for the run
    :param callback: optional callable that receives each intermediate state
    :param control_data: optional ControlNet settings, passed on unchanged
    :param kwargs: forwarded to ``generate_latents_from_embeddings``
    :return: the latents and attention map saver of the last intermediate state
    """
    cpu_only = self.scheduler.config.get("cpu_only", False)
    scheduler_device = (
        torch.device('cpu') if cpu_only else self._model_group.device_for(self.unet)
    )

    if timesteps is None:
        self.scheduler.set_timesteps(num_inference_steps, device=scheduler_device)
        timesteps = self.scheduler.timesteps

    run_with_callback = GeneratorToCallbackinator(
        self.generate_latents_from_embeddings, PipelineIntermediateState
    )
    final_state: PipelineIntermediateState = run_with_callback(
        latents,
        timesteps,
        conditioning_data,
        noise=noise,
        additional_guidance=additional_guidance,
        run_id=run_id,
        callback=callback,
        control_data=control_data,
        **kwargs,
    )
    return final_state.latents, final_state.attention_map_saver
2023-03-03 06:02:00 +00:00
def generate_latents_from_embeddings(
    self,
    latents: torch.Tensor,
    timesteps,
    conditioning_data: ConditioningData,
    *,
    noise: torch.Tensor,
    run_id: str = None,
    additional_guidance: List[Callable] = None,
    control_data: List[ControlNetData] = None,
    **kwargs,
):
    """Generator that denoises ``latents`` and yields a state after every step.

    A first state with ``step=-1`` is yielded before noise is added; one further
    state follows each timestep. The generator's return value is the final latents
    together with the attention map saver.

    :param latents: starting latents
    :param timesteps: timesteps to iterate over
    :param conditioning_data: conditioning for the run
    :param noise: noise added to the latents at the first timestep
    :param run_id: identifier for the run; a random token is generated when None
    :param additional_guidance: optional guidance callables handed to ``step``
    :param control_data: optional ControlNet settings handed to ``step``
    :param kwargs: forwarded to ``step``
    """
    self._adjust_memory_efficient_attention(latents)
    run_id = secrets.token_urlsafe(self.ID_LENGTH) if run_id is None else run_id
    additional_guidance = [] if additional_guidance is None else additional_guidance

    with self.invokeai_diffuser.custom_attention_context(
        self.invokeai_diffuser.model,
        extra_conditioning_info=conditioning_data.extra,
        step_count=len(self.scheduler.timesteps),
    ):
        yield PipelineIntermediateState(
            run_id=run_id,
            step=-1,
            timestep=self.scheduler.config.num_train_timesteps,
            latents=latents,
        )

        # One timestep entry per batch item, refilled in place on every step.
        batched_t = torch.full(
            (latents.shape[0],),
            timesteps[0],
            dtype=timesteps.dtype,
            device=self._model_group.device_for(self.unet),
        )
        latents = self.scheduler.add_noise(latents, noise, batched_t)

        # TODO resuscitate attention map saving
        attention_map_saver: Optional[AttentionMapSaver] = None

        total_steps = len(timesteps)
        for step_index, timestep in enumerate(self.progress_bar(timesteps)):
            batched_t.fill_(timestep)
            step_output = self.step(
                batched_t,
                latents,
                conditioning_data,
                step_index=step_index,
                total_step_count=total_steps,
                additional_guidance=additional_guidance,
                control_data=control_data,
                **kwargs,
            )
            latents = self.invokeai_diffuser.do_latent_postprocessing(
                postprocessing_settings=conditioning_data.postprocessing_settings,
                latents=step_output.prev_sample,
                sigma=batched_t,
                step_index=step_index,
                total_step_count=total_steps,
            )

            yield PipelineIntermediateState(
                run_id=run_id,
                step=step_index,
                timestep=int(timestep),
                latents=latents,
                predicted_original=getattr(step_output, "pred_original_sample", None),
                attention_map_saver=attention_map_saver,
            )

        return latents, attention_map_saver
@torch.inference_mode()
def step(
    self,
    t: torch.Tensor,
    latents: torch.Tensor,
    conditioning_data: ConditioningData,
    step_index: int,
    total_step_count: int,
    additional_guidance: List[Callable] = None,
    control_data: List[ControlNetData] = None,
    **kwargs,
):
    """Run one denoising step and return the scheduler's step output.

    The latents are scaled by the scheduler, optionally passed through every
    active ControlNet, handed to ``self.invokeai_diffuser.do_diffusion_step``
    to predict the noise, and finally advanced with ``self.scheduler.step``.

    Args:
        t: Batched timestep tensor. Only ``t[0]`` is given to the scheduler and
            the ControlNets; the whole tensor is passed to the diffuser as ``sigma``.
        latents: Current latents (x_t).
        conditioning_data: Supplies ``text_embeddings``, ``unconditioned_embeddings``,
            ``guidance_scale`` and ``scheduler_args``.
        step_index: Index of the current step.
        total_step_count: Total number of steps in the run; used to turn each
            ControlNet's begin/end percentages into step indices.
        additional_guidance: Callables invoked in order as
            ``guidance(step_output, timestep, conditioning_data)`` after the
            scheduler step; each one's return value replaces ``step_output``.
        control_data: One entry per ControlNet (a single entry for plain
            ControlNet, several for MultiControlNet). ``None`` disables ControlNet.
        **kwargs: Accepted but not used by this method.

    Returns:
        The output of ``self.scheduler.step`` after all ``additional_guidance``
        callables have been applied.
    """
    # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
    timestep = t[0]

    if additional_guidance is None:
        additional_guidance = []

    # TODO: should this scaling happen here or inside self._unet_forward?
    #     i.e. before or after passing it to InvokeAIDiffuserComponent
    unet_latent_input = self.scheduler.scale_model_input(latents, timestep)

    # default is no controlnet, so the residuals handed to the UNet stay None
    down_block_res_samples, mid_block_res_sample = None, None

    if control_data is not None:
        for control_datum in control_data:
            control_mode = control_datum.control_mode
            # soft_injection and cfg_injection are the two ControlNet booleans that are
            # combined at a higher level to make the control_mode enum.
            # soft_injection: per-layer re-weighting adjustment (True) or default weighting (False)
            soft_injection = control_mode in ("more_prompt", "more_control")
            # cfg_injection: apply ControlNet to the conditional batch only (True)
            # or to both the unconditional and conditional batches (False)
            cfg_injection = control_mode in ("more_control", "unbalanced")

            first_control_step = math.floor(control_datum.begin_step_percent * total_step_count)
            last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
            # only apply this controlnet if the current step is within its begin/end step range
            if not (first_control_step <= step_index <= last_control_step):
                continue

            if cfg_injection:
                # ControlNet sees only the conditional batch, so it must be given the
                # *conditional* (text) embeddings. Its output is zero-padded below so that
                # the unconditional half of the UNet batch is left untouched.
                control_latent_input = unet_latent_input
                encoder_hidden_states = conditioning_data.text_embeddings
            else:
                # expand the latents input to the control model for classifier free guidance
                # (unconditional batch first, conditional batch second)
                control_latent_input = torch.cat([unet_latent_input] * 2)
                encoder_hidden_states = torch.cat(
                    [
                        conditioning_data.unconditioned_embeddings,
                        conditioning_data.text_embeddings,
                    ]
                )

            if isinstance(control_datum.weight, list):
                # if controlnet has multiple weights, use the weight for the current step
                controlnet_weight = control_datum.weight[step_index]
            else:
                # if controlnet has a single weight, use it for all steps
                controlnet_weight = control_datum.weight

            # controlnet(s) inference
            down_samples, mid_sample = control_datum.model(
                sample=control_latent_input,
                timestep=timestep,
                encoder_hidden_states=encoder_hidden_states,
                controlnet_cond=control_datum.image_tensor,
                conditioning_scale=controlnet_weight,  # controlnet specific, NOT the guidance scale
                guess_mode=soft_injection,  # this is still called guess_mode in diffusers ControlNetModel
                return_dict=False,
            )

            if cfg_injection:
                # Inferred ControlNet only for the conditional batch.
                # To apply the output of ControlNet to both the unconditional and conditional batches,
                # add 0 to the unconditional batch to keep it unchanged.
                down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples]
                mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample])

            if down_block_res_samples is None and mid_block_res_sample is None:
                down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
            else:
                # add controlnet outputs together if have multiple controlnets
                down_block_res_samples = [
                    samples_prev + samples_curr
                    for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
                ]
                mid_block_res_sample += mid_sample

    # predict the noise residual
    noise_pred = self.invokeai_diffuser.do_diffusion_step(
        x=unet_latent_input,
        sigma=t,
        unconditioning=conditioning_data.unconditioned_embeddings,
        conditioning=conditioning_data.text_embeddings,
        unconditional_guidance_scale=conditioning_data.guidance_scale,
        step_index=step_index,
        total_step_count=total_step_count,
        down_block_additional_residuals=down_block_res_samples,  # from controlnet(s)
        mid_block_additional_residual=mid_block_res_sample,  # from controlnet(s)
    )

    # compute the previous noisy sample x_t -> x_t-1
    step_output = self.scheduler.step(
        noise_pred, timestep, latents, **conditioning_data.scheduler_args
    )

    # TODO: this additional_guidance extension point feels redundant with InvokeAIDiffusionComponent.
    #    But the way things are now, scheduler runs _after_ that, so there was
    #    no way to use it to apply an operation that happens after the last scheduler.step.
    for guidance in additional_guidance:
        step_output = guidance(step_output, timestep, conditioning_data)

    return step_output
2023-03-03 06:02:00 +00:00
def _unet_forward(
    self,
    latents,
    t,
    text_embeddings,
    cross_attention_kwargs: Optional[dict[str, Any]] = None,
    **kwargs,
):
    """Predict the noise residual by calling the UNet and returning its ``.sample``.

    When the UNet is an inpainting model but ``latents`` has only 4 channels,
    the input is first padded with an all-ones mask and all-zeros image latents
    via ``AddsMaskLatents.add_mask_channels``.
    """
    needs_mask_padding = is_inpainting_model(self.unet) and latents.size(1) == 4
    if needs_mask_padding:
        # Pad out normal non-inpainting inputs for an inpainting model.
        # FIXME: There are too many layers of functions and we have too many different ways of
        #     overriding things! This should get handled in a way more consistent with the other
        #     use of AddsMaskLatents.
        padder = AddsMaskLatents(
            self._unet_forward,
            mask=torch.ones_like(
                latents[:1, :1], device=latents.device, dtype=latents.dtype
            ),
            initial_image_latents=torch.zeros_like(
                latents[:1], device=latents.device, dtype=latents.dtype
            ),
        )
        latents = padder.add_mask_channels(latents)

    # First three args should be positional, not keywords, so torch hooks can see them.
    unet_output = self.unet(
        latents,
        t,
        text_embeddings,
        cross_attention_kwargs=cross_attention_kwargs,
        **kwargs,
    )
    return unet_output.sample
def img2img_from_embeddings(
    self,
    init_image: Union[torch.FloatTensor, PIL.Image.Image],
    strength: float,
    num_inference_steps: int,
    conditioning_data: ConditioningData,
    *,
    callback: Callable[[PipelineIntermediateState], None] = None,
    run_id=None,
    noise_func=None,
    seed=None,
) -> InvokeAIStableDiffusionPipelineOutput:
    """Image-to-image entry point that starts from a pixel-space image.

    The image is converted to a 4-D tensor, encoded to latents, noise is drawn
    with ``noise_func`` and the work is delegated to
    ``img2img_from_latents_and_embeddings``.

    NOTE: ``noise_func`` is called unconditionally, so despite its ``None``
    default a callable must be supplied.
    """
    if isinstance(init_image, PIL.Image.Image):
        rgb_image = init_image.convert("RGB")
        init_image = image_resized_to_grid_as_tensor(rgb_image)

    # a single (c, h, w) image gets a leading batch axis
    if init_image.dim() == 3:
        init_image = einops.rearrange(init_image, "c h w -> 1 c h w")

    # 6. Prepare latent variables
    unet_device = self._model_group.device_for(self.unet)
    initial_latents = self.non_noised_latents_from_image(
        init_image,
        device=unet_device,
        dtype=self.unet.dtype,
    )

    # seed only after encoding, so the seed governs the noise drawn below
    if seed is not None:
        set_seed(seed)
    noise = noise_func(initial_latents)

    return self.img2img_from_latents_and_embeddings(
        initial_latents,
        num_inference_steps,
        conditioning_data,
        strength,
        noise,
        run_id,
        callback,
    )
def img2img_from_latents_and_embeddings(
    self,
    initial_latents,
    num_inference_steps,
    conditioning_data: ConditioningData,
    strength,
    noise: torch.Tensor,
    run_id=None,
    callback=None,
) -> InvokeAIStableDiffusionPipelineOutput:
    """Image-to-image starting from already-encoded latents.

    Computes the img2img timestep schedule for ``strength``, denoises with
    ``latents_from_embeddings``, decodes the result and runs it through
    ``check_for_safety``.
    """
    timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength)

    if strength < 1.0:
        starting_latents = initial_latents
    else:
        # at strength >= 1.0 the initial image is replaced by zeros of the same shape
        starting_latents = torch.zeros_like(
            initial_latents, device=initial_latents.device, dtype=initial_latents.dtype
        )

    result_latents, result_attention_maps = self.latents_from_embeddings(
        latents=starting_latents,
        num_inference_steps=num_inference_steps,
        conditioning_data=conditioning_data,
        timesteps=timesteps,
        noise=noise,
        run_id=run_id,
        callback=callback,
    )

    # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
    torch.cuda.empty_cache()

    with torch.inference_mode():
        decoded = self.decode_latents(result_latents)
        unscreened = InvokeAIStableDiffusionPipelineOutput(
            images=decoded,
            nsfw_content_detected=[],
            attention_map_saver=result_attention_maps,
        )
        return self.check_for_safety(unscreened, dtype=conditioning_data.dtype)
2023-03-03 06:02:00 +00:00
def get_img2img_timesteps(
    self, num_inference_steps: int, strength: float, device=None
) -> (torch.Tensor, int):
    """Return ``(timesteps, adjusted_steps)`` for an img2img run at ``strength``.

    A temporary ``StableDiffusionImg2ImgPipeline`` built from this pipeline's
    components is used for its ``get_timesteps`` logic; it shares this
    pipeline's scheduler. The ``device`` argument is not used: the scheduler
    device is CPU when the scheduler config has ``cpu_only`` set, otherwise
    the UNet's device.
    """
    helper_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
    assert helper_pipeline.scheduler is self.scheduler

    cpu_only = self.scheduler.config.get("cpu_only", False)
    if cpu_only:
        scheduler_device = torch.device('cpu')
    else:
        scheduler_device = self._model_group.device_for(self.unet)

    helper_pipeline.scheduler.set_timesteps(num_inference_steps, device=scheduler_device)
    schedule = helper_pipeline.get_timesteps(
        num_inference_steps, strength, device=scheduler_device
    )
    timesteps, adjusted_steps = schedule

    # Workaround for low strength resulting in zero timesteps.
    # TODO: submit upstream fix for zero-step img2img
    if timesteps.numel() == 0:
        # fall back to just the final scheduler timestep
        timesteps = self.scheduler.timesteps[-1:]
        adjusted_steps = timesteps.numel()

    return timesteps, adjusted_steps
def inpaint_from_embeddings(
    self,
    init_image: torch.FloatTensor,
    mask: torch.FloatTensor,
    strength: float,
    num_inference_steps: int,
    conditioning_data: ConditioningData,
    *,
    callback: Callable[[PipelineIntermediateState], None] = None,
    run_id=None,
    noise_func=None,
    seed=None,
) -> InvokeAIStableDiffusionPipelineOutput:
    """Inpaint ``init_image`` under ``mask`` and return the safety-checked output.

    With an inpainting UNet the model forward callback is temporarily replaced
    by an ``AddsMaskLatents`` wrapper (restored in a ``finally``); otherwise an
    ``AddsMaskGuidance`` callable is passed as additional guidance.

    NOTE: ``noise_func`` is called unconditionally, so despite its ``None``
    default a callable must be supplied.
    """
    device = self._model_group.device_for(self.unet)
    latents_dtype = self.unet.dtype
    placement = {"device": device, "dtype": latents_dtype}

    if isinstance(init_image, PIL.Image.Image):
        rgb_image = init_image.convert("RGB")
        init_image = image_resized_to_grid_as_tensor(rgb_image)

    init_image = init_image.to(**placement)
    mask = mask.to(**placement)

    if init_image.dim() == 3:
        init_image = init_image.unsqueeze(0)

    timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength)

    # 6. Prepare latent variables
    # can't quite use upstream StableDiffusionImg2ImgPipeline.prepare_latents
    # because we have our own noise function
    init_image_latents = self.non_noised_latents_from_image(init_image, **placement)
    # seed only after encoding, so the seed governs the noise drawn below
    if seed is not None:
        set_seed(seed)
    noise = noise_func(init_image_latents)

    if mask.dim() == 3:
        mask = mask.unsqueeze(0)
    # bring the mask down to the spatial size of the latents
    resized_mask = tv_resize(
        mask, init_image_latents.shape[-2:], T.InterpolationMode.BILINEAR
    )
    latent_mask = resized_mask.to(**placement)

    guidance: List[Callable] = []

    if not is_inpainting_model(self.unet):
        guidance.append(
            AddsMaskGuidance(latent_mask, init_image_latents, self.scheduler, noise)
        )
    else:
        # You'd think the inpainting model wouldn't be paying attention to the area it is going to repaint
        # (that's why there's a mask!) but it seems to really want that blanked out.
        keep_pixels = torch.where(mask < 0.5, 1, 0)
        masked_init_image = init_image * keep_pixels
        masked_latents = self.non_noised_latents_from_image(masked_init_image, **placement)

        # TODO: we should probably pass this in so we don't have to try/finally around setting it.
        self.invokeai_diffuser.model_forward_callback = AddsMaskLatents(
            self._unet_forward, latent_mask, masked_latents
        )

    if strength < 1.0:
        starting_latents = init_image_latents
    else:
        # at strength >= 1.0 the initial image is replaced by zeros of the same shape
        starting_latents = torch.zeros_like(
            init_image_latents, device=init_image_latents.device, dtype=init_image_latents.dtype
        )

    try:
        result_latents, result_attention_maps = self.latents_from_embeddings(
            latents=starting_latents,
            num_inference_steps=num_inference_steps,
            conditioning_data=conditioning_data,
            noise=noise,
            timesteps=timesteps,
            additional_guidance=guidance,
            run_id=run_id,
            callback=callback,
        )
    finally:
        # always restore the plain UNet forward, even if denoising failed
        self.invokeai_diffuser.model_forward_callback = self._unet_forward

    # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
    torch.cuda.empty_cache()

    with torch.inference_mode():
        decoded = self.decode_latents(result_latents)
        unscreened = InvokeAIStableDiffusionPipelineOutput(
            images=decoded,
            nsfw_content_detected=[],
            attention_map_saver=result_attention_maps,
        )
        return self.check_for_safety(unscreened, dtype=conditioning_data.dtype)
def non_noised_latents_from_image(self, init_image, *, device: torch.device, dtype):
    """Encode ``init_image`` with the VAE and return scaled latents without added noise.

    The image is moved to ``device``/``dtype``, the VAE is loaded through the
    model group, one sample is drawn from the encoder's latent distribution,
    and the sample is multiplied by 0.18215.
    """
    pixels = init_image.to(device=device, dtype=dtype)

    with torch.inference_mode():
        # `encode` isn't the forward method, so make sure the VAE is loaded first
        self._model_group.load(self.vae)
        latent_dist = self.vae.encode(pixels).latent_dist
        # FIXME: uses torch.randn. make reproducible!
        sampled = latent_dist.sample().to(dtype=dtype)

    # NOTE(review): 0.18215 is presumably the VAE latent scaling factor — confirm
    return 0.18215 * sampled
def check_for_safety(self, output, dtype):
    """Run the safety checker on ``output.images`` and return a screened output.

    Args:
        output: Pipeline output providing ``images`` and ``attention_map_saver``.
        dtype: Forwarded to ``run_safety_checker``.

    Returns:
        A new ``InvokeAIStableDiffusionPipelineOutput`` holding the screened
        images and the checker's NSFW result. The attention map saver is
        carried over only when no image was flagged; otherwise it is ``None``.
    """
    with torch.inference_mode():
        screened_images, has_nsfw_concept = self.run_safety_checker(
            output.images, dtype=dtype
        )

    # The checker result may be None (no checker) or a per-image sequence of flags.
    # A non-empty list such as [False] is truthy, so plain `not has_nsfw_concept`
    # would treat "nothing flagged" as flagged; inspect the individual flags instead.
    if has_nsfw_concept is None:
        flagged = False
    elif isinstance(has_nsfw_concept, (list, tuple)):
        flagged = any(has_nsfw_concept)
    else:
        flagged = bool(has_nsfw_concept)

    # block the attention maps if NSFW content is detected
    screened_attention_map_saver = None if flagged else output.attention_map_saver

    return InvokeAIStableDiffusionPipelineOutput(
        screened_images,
        has_nsfw_concept,
        attention_map_saver=screened_attention_map_saver,
    )
2023-02-28 05:37:13 +00:00
def run_safety_checker(self, image, device=None, dtype=None):
    """Run the parent class's safety checker, picking the device automatically.

    When a safety checker is configured, the ``device`` argument is replaced
    by the device the model group reports for it, so callers need not know it.
    """
    checker = self.safety_checker
    if checker is not None:
        # use the model group for device info instead of requiring the caller to know
        device = self._model_group.device_for(checker)
    return super().run_safety_checker(image, device, dtype)
def decode_latents(self, latents):
    """Load the VAE through the model group, then decode via the parent class."""
    # Explicit call to get the vae loaded, since `decode` isn't the forward method.
    vae = self.vae
    self._model_group.load(vae)
    decoded = super().decode_latents(latents)
    return decoded
def debug_latents(self, latents, msg):
    """Decode ``latents`` to PIL images and pass each one to ``debug_image``.

    Each image is captioned ``"latents {msg} {k}/{total}"`` with ``k`` counted from 1.
    """
    from invokeai.backend.image_util import debug_image

    with torch.inference_mode():
        pil_images = self.numpy_to_pil(self.decode_latents(latents))

    total = len(pil_images)
    for position, pil_image in enumerate(pil_images, start=1):
        debug_image(
            pil_image, f"latents {msg} {position}/{total}", debug_status=True
        )
# Copied from diffusers pipeline_stable_diffusion_controlnet.py
# Returns torch.Tensor of shape (batch_size, 3, height, width)
2023-06-01 22:09:49 +00:00
@staticmethod
def prepare_control_image(
    image,
    # FIXME: need to fix hardwiring of width and height, change to basing on latents dimensions?
    # latents,
    width=512,  # should be 8 * latent.shape[3]
    height=512,  # should be 8 * latent height[2]
    batch_size=1,
    num_images_per_prompt=1,
    device="cuda",
    dtype=torch.float16,
    do_classifier_free_guidance=True,
    control_mode="balanced"
):
    """Turn a ControlNet conditioning image into a batched tensor.

    ``image`` may be a tensor, a single PIL image, or a list of PIL images or
    tensors. PIL images are converted to RGB, resized to ``(width, height)``
    with Lanczos resampling, scaled to [0, 1] and laid out channels-first;
    tensors are used as given (lists are concatenated along dim 0).

    The batch is then repeated along dim 0 (``batch_size`` times when it holds
    one image, otherwise ``num_images_per_prompt`` times), moved to
    ``device``/``dtype``, and duplicated once more when classifier-free
    guidance is on and ``control_mode`` is neither ``"more_control"`` nor
    ``"unbalanced"``.
    """
    if not isinstance(image, torch.Tensor):
        # normalise a lone PIL image to a one-element list
        if isinstance(image, PIL.Image.Image):
            image = [image]

        first = image[0]
        if isinstance(first, PIL.Image.Image):
            frames = [
                np.array(
                    pil_image.convert("RGB").resize(
                        (width, height), resample=PIL_INTERPOLATION["lanczos"]
                    )
                )
                for pil_image in image
            ]
            # (n, h, w, c) uint8 -> (n, c, h, w) float32 in [0, 1]
            stacked = np.stack(frames, axis=0).astype(np.float32) / 255.0
            image = torch.from_numpy(stacked.transpose(0, 3, 1, 2))
        elif isinstance(first, torch.Tensor):
            image = torch.cat(image, dim=0)

    if image.shape[0] == 1:
        repeats = batch_size
    else:
        # image batch size is the same as prompt batch size
        repeats = num_images_per_prompt

    image = image.repeat_interleave(repeats, dim=0).to(device=device, dtype=dtype)

    # in these modes ControlNet is applied to the conditional batch only,
    # so the image is not doubled for classifier free guidance
    conditional_only = control_mode in ("more_control", "unbalanced")
    if do_classifier_free_guidance and not conditional_only:
        image = torch.cat([image] * 2)
    return image