mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
1302 lines
55 KiB
Python
1302 lines
55 KiB
Python
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
|
import inspect
|
|
from contextlib import ExitStack
|
|
from functools import singledispatchmethod
|
|
from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union
|
|
|
|
import einops
|
|
import numpy as np
|
|
import torch
|
|
import torchvision
|
|
import torchvision.transforms as T
|
|
from diffusers.configuration_utils import ConfigMixin
|
|
from diffusers.image_processor import VaeImageProcessor
|
|
from diffusers.models.adapter import T2IAdapter
|
|
from diffusers.models.attention_processor import (
|
|
AttnProcessor2_0,
|
|
LoRAAttnProcessor2_0,
|
|
LoRAXFormersAttnProcessor,
|
|
XFormersAttnProcessor,
|
|
)
|
|
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
|
|
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
|
|
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
|
from diffusers.schedulers.scheduling_dpmsolver_sde import DPMSolverSDEScheduler
|
|
from diffusers.schedulers.scheduling_tcd import TCDScheduler
|
|
from diffusers.schedulers.scheduling_utils import SchedulerMixin as Scheduler
|
|
from PIL import Image, ImageFilter
|
|
from pydantic import field_validator
|
|
from torchvision.transforms.functional import resize as tv_resize
|
|
from transformers import CLIPVisionModelWithProjection
|
|
|
|
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
|
|
from invokeai.app.invocations.fields import (
|
|
ConditioningField,
|
|
DenoiseMaskField,
|
|
FieldDescriptions,
|
|
ImageField,
|
|
Input,
|
|
InputField,
|
|
LatentsField,
|
|
OutputField,
|
|
UIType,
|
|
WithBoard,
|
|
WithMetadata,
|
|
)
|
|
from invokeai.app.invocations.ip_adapter import IPAdapterField
|
|
from invokeai.app.invocations.primitives import DenoiseMaskOutput, ImageOutput, LatentsOutput
|
|
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
|
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
|
from invokeai.app.util.controlnet_utils import prepare_control_image
|
|
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
|
from invokeai.backend.lora import LoRAModelRaw
|
|
from invokeai.backend.model_manager import BaseModelType, LoadedModel
|
|
from invokeai.backend.model_manager.config import MainConfigBase, ModelVariantType
|
|
from invokeai.backend.model_patcher import ModelPatcher
|
|
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
|
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
|
BasicConditioningInfo,
|
|
IPAdapterConditioningInfo,
|
|
IPAdapterData,
|
|
Range,
|
|
SDXLConditioningInfo,
|
|
TextConditioningData,
|
|
TextConditioningRegions,
|
|
)
|
|
from invokeai.backend.util.mask import to_standard_float_mask
|
|
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
|
|
|
from ...backend.stable_diffusion.diffusers_pipeline import (
|
|
ControlNetData,
|
|
StableDiffusionGeneratorPipeline,
|
|
T2IAdapterData,
|
|
image_resized_to_grid_as_tensor,
|
|
)
|
|
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
|
from ...backend.util.devices import TorchDevice
|
|
from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
|
|
from .controlnet_image_processors import ControlField
|
|
from .model import ModelIdentifierField, UNetField, VAEField
|
|
|
|
# Precision picked once, at import time, by TorchDevice.choose_torch_dtype().
# The `fp32` input defaults of the mask invocations below are derived from it.
# NOTE(review): presumably a torch.dtype rather than a string -- confirm against TorchDevice.
DEFAULT_PRECISION = TorchDevice.choose_torch_dtype()
|
|
|
|
|
|
# Output type registered as "scheduler_output": carries a single scheduler name.
# NOTE(review): deliberately documented with comments rather than a class docstring, since a
# docstring on a model class may end up in the generated schema -- confirm before adding one.
@invocation_output("scheduler_output")
class SchedulerOutput(BaseInvocationOutput):
    # The scheduler name, tagged with the Scheduler UI type.
    scheduler: SCHEDULER_NAME_VALUES = OutputField(
        description=FieldDescriptions.scheduler,
        ui_type=UIType.Scheduler,
    )
|
|
|
|
|
|
@invocation("scheduler", title="Scheduler", tags=["scheduler"], category="latents", version="1.0.0")
class SchedulerInvocation(BaseInvocation):
    """Selects a scheduler."""

    # The scheduler to pass along; defaults to "euler".
    scheduler: SCHEDULER_NAME_VALUES = InputField(
        default="euler", description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler
    )

    def invoke(self, context: InvocationContext) -> SchedulerOutput:
        """Wrap the configured scheduler name in a SchedulerOutput (the context is not used)."""
        selected_scheduler = self.scheduler
        return SchedulerOutput(scheduler=selected_scheduler)
|
|
|
|
|
|
@invocation(
    "create_denoise_mask",
    title="Create Denoise Mask",
    tags=["mask", "denoise"],
    category="latents",
    version="1.0.2",
)
class CreateDenoiseMaskInvocation(BaseInvocation):
    """Creates mask for denoising model run."""

    vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection, ui_order=0)
    image: Optional[ImageField] = InputField(default=None, description="Image which will be masked", ui_order=1)
    mask: ImageField = InputField(description="The mask to use when pasting", ui_order=2)
    tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=3)
    fp32: bool = InputField(
        # DEFAULT_PRECISION comes from TorchDevice.choose_torch_dtype(). Assuming that returns a
        # torch.dtype (TODO confirm), comparing it only against the string "float32" can never be
        # true, so the default would always be False. Accept both representations.
        default=DEFAULT_PRECISION in (torch.float32, "float32"),
        description=FieldDescriptions.fp32,
        ui_order=4,
    )

    def prep_mask_tensor(self, mask_image: Image.Image) -> torch.Tensor:
        """Turn a PIL mask image into a tensor.

        The image is converted to mode "L" when it is not already, run through
        `image_resized_to_grid_as_tensor` with `normalize=False`, and unsqueezed at dim 0 when
        that helper returns a 3-dimensional tensor.
        """
        if mask_image.mode != "L":
            mask_image = mask_image.convert("L")
        mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
        if mask_tensor.dim() == 3:
            mask_tensor = mask_tensor.unsqueeze(0)
        return mask_tensor

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
        """Save the mask tensor and, when an image is supplied, the VAE-encoded masked image.

        Returns:
            A DenoiseMaskOutput built with the saved mask name, the saved masked-latents name
            (None when no image was provided) and `gradient=False`.
        """
        image_tensor: Optional[torch.Tensor] = None
        if self.image is not None:
            image = context.images.get_pil(self.image.image_name)
            image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
            if image_tensor.dim() == 3:
                image_tensor = image_tensor.unsqueeze(0)

        mask = self.prep_mask_tensor(context.images.get_pil(self.mask.image_name))

        masked_latents_name = None
        if image_tensor is not None:
            vae_info = context.models.load(self.vae.vae)

            # Resize the mask to the image's spatial size, then zero every image pixel whose
            # mask value is below 0.5 before encoding.
            img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
            masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
            masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone())

            masked_latents_name = context.tensors.save(tensor=masked_latents)

        mask_name = context.tensors.save(tensor=mask)

        return DenoiseMaskOutput.build(
            mask_name=mask_name,
            masked_latents_name=masked_latents_name,
            gradient=False,
        )
|
|
|
|
|
|
@invocation_output("gradient_mask_output")
class GradientMaskOutput(BaseInvocationOutput):
    """Outputs a denoise mask and an image representing the total gradient of the mask."""

    # Reference to the saved mask tensor (and optional masked latents).
    denoise_mask: DenoiseMaskField = OutputField(
        description="Mask for denoise model run",
    )
    # Reference to the saved image that covers the whole gradient region.
    expanded_mask_area: ImageField = OutputField(
        description="Image representing the total gradient area of the mask. For paste-back purposes.",
    )
|
|
|
|
|
|
@invocation(
    "create_gradient_mask",
    title="Create Gradient Mask",
    tags=["mask", "denoise"],
    category="latents",
    version="1.1.0",
)
class CreateGradientMaskInvocation(BaseInvocation):
    """Creates mask for denoising model run."""

    mask: ImageField = InputField(default=None, description="Image which will be masked", ui_order=1)
    edge_radius: int = InputField(
        default=16, ge=0, description="How far to blur/expand the edges of the mask", ui_order=2
    )
    coherence_mode: Literal["Gaussian Blur", "Box Blur", "Staged"] = InputField(default="Gaussian Blur", ui_order=3)
    minimum_denoise: float = InputField(
        default=0.0, ge=0, le=1, description="Minimum denoise level for the coherence region", ui_order=4
    )
    image: Optional[ImageField] = InputField(
        default=None,
        description="OPTIONAL: Only connect for specialized Inpainting models, masked_latents will be generated from the image with the VAE",
        title="[OPTIONAL] Image",
        ui_order=6,
    )
    unet: Optional[UNetField] = InputField(
        description="OPTIONAL: If the Unet is a specialized Inpainting model, masked_latents will be generated from the image with the VAE",
        default=None,
        input=Input.Connection,
        title="[OPTIONAL] UNet",
        ui_order=5,
    )
    vae: Optional[VAEField] = InputField(
        default=None,
        description="OPTIONAL: Only connect for specialized Inpainting models, masked_latents will be generated from the image with the VAE",
        title="[OPTIONAL] VAE",
        input=Input.Connection,
        ui_order=7,
    )
    tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=8)
    fp32: bool = InputField(
        # DEFAULT_PRECISION comes from TorchDevice.choose_torch_dtype(). Assuming that returns a
        # torch.dtype (TODO confirm), comparing it only against the string "float32" can never be
        # true, so the default would always be False. Accept both representations.
        default=DEFAULT_PRECISION in (torch.float32, "float32"),
        description=FieldDescriptions.fp32,
        ui_order=9,
    )

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> GradientMaskOutput:
        """Build a (possibly blurred) denoise mask plus an image of its expanded area.

        When `unet`, `vae` and `image` are all connected and the UNet's config variant is
        Inpaint, the masked image is also VAE-encoded and saved as masked latents.

        Returns:
            A GradientMaskOutput whose denoise mask is flagged `gradient=True`.
        """
        mask_image = context.images.get_pil(self.mask.image_name, mode="L")

        # Declared once so both branches assign the same annotated name.
        blur_tensor: torch.Tensor
        if self.edge_radius > 0:
            if self.coherence_mode == "Box Blur":
                blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius))
            else:  # Gaussian Blur OR Staged
                # Gaussian Blur uses standard deviation. 1/2 radius is a good approximation
                blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2))

            blur_tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False)

            # redistribute blur so that the original edges are 0 and blur outwards to 1
            blur_tensor = (blur_tensor - 0.5) * 2

            threshold = 1 - self.minimum_denoise

            if self.coherence_mode == "Staged":
                # wherever the blur_tensor is less than fully masked, convert it to threshold
                blur_tensor = torch.where((blur_tensor < 1) & (blur_tensor > 0), threshold, blur_tensor)
            else:
                # wherever the blur_tensor is above threshold but less than 1, drop it to threshold
                blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)
        else:
            # No edge expansion requested: use the mask as-is.
            blur_tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)

        mask_name = context.tensors.save(tensor=blur_tensor.unsqueeze(1))

        # compute a [0, 1] mask from the blur_tensor
        expanded_mask = torch.where((blur_tensor < 1), 0, 1)
        expanded_mask_image = Image.fromarray((expanded_mask.squeeze(0).numpy() * 255).astype(np.uint8), mode="L")
        expanded_image_dto = context.images.save(expanded_mask_image)

        masked_latents_name = None
        if self.unet is not None and self.vae is not None and self.image is not None:
            # all three fields must be present at the same time
            main_model_config = context.models.get_config(self.unet.unet.key)
            assert isinstance(main_model_config, MainConfigBase)
            if main_model_config.variant is ModelVariantType.Inpaint:
                vae_info: LoadedModel = context.models.load(self.vae.vae)
                image = context.images.get_pil(self.image.image_name)
                image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
                if image_tensor.dim() == 3:
                    image_tensor = image_tensor.unsqueeze(0)
                # Resize the mask to the image's spatial size and zero image pixels where the
                # resized mask is below 0.5 before encoding.
                img_mask = tv_resize(
                    blur_tensor, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False
                )
                masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
                masked_latents = ImageToLatentsInvocation.vae_encode(
                    vae_info, self.fp32, self.tiled, masked_image.clone()
                )
                masked_latents_name = context.tensors.save(tensor=masked_latents)

        return GradientMaskOutput(
            denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=masked_latents_name, gradient=True),
            expanded_mask_area=ImageField(image_name=expanded_image_dto.image_name),
        )
|
|
|
|
|
|
def get_scheduler(
    context: InvocationContext,
    scheduler_info: ModelIdentifierField,
    scheduler_name: str,
    seed: int,
) -> Scheduler:
    """Instantiate the scheduler registered under `scheduler_name`.

    The scheduler class and its extra config come from SCHEDULER_MAP (falling back to the
    "ddim" entry for unknown names). The base config is read from the scheduler loaded via
    `scheduler_info`; if that config carries a "_backup" entry, the backup is used as the base.
    """
    scheduler_class, extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])

    with context.models.load(scheduler_info) as loaded_scheduler:
        base_config = loaded_scheduler.config

    # Prefer a previously stashed pristine config over one that was already overridden.
    if "_backup" in base_config:
        base_config = base_config["_backup"]

    # FIXME (kept from the original): the extra config is merged over the base config.
    merged_config = {**base_config, **extra_config, "_backup": base_config}

    # Make dpmpp_sde reproducible (the seed can only be passed in the initializer).
    if scheduler_class is DPMSolverSDEScheduler:
        merged_config["noise_sampler_seed"] = seed

    scheduler = scheduler_class.from_config(merged_config)

    # Hack copied over from generate.py.
    if not hasattr(scheduler, "uses_inpainting_model"):
        scheduler.uses_inpainting_model = lambda: False

    assert isinstance(scheduler, Scheduler)
    return scheduler
|
|
|
|
|
|
@invocation(
|
|
"denoise_latents",
|
|
title="Denoise Latents",
|
|
tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
|
|
category="latents",
|
|
version="1.5.3",
|
|
)
|
|
class DenoiseLatentsInvocation(BaseInvocation):
|
|
"""Denoises noisy latents to decodable images"""
|
|
|
|
positive_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField(
|
|
description=FieldDescriptions.positive_cond, input=Input.Connection, ui_order=0
|
|
)
|
|
negative_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField(
|
|
description=FieldDescriptions.negative_cond, input=Input.Connection, ui_order=1
|
|
)
|
|
noise: Optional[LatentsField] = InputField(
|
|
default=None,
|
|
description=FieldDescriptions.noise,
|
|
input=Input.Connection,
|
|
ui_order=3,
|
|
)
|
|
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
|
|
cfg_scale: Union[float, List[float]] = InputField(
|
|
default=7.5, description=FieldDescriptions.cfg_scale, title="CFG Scale"
|
|
)
|
|
denoising_start: float = InputField(
|
|
default=0.0,
|
|
ge=0,
|
|
le=1,
|
|
description=FieldDescriptions.denoising_start,
|
|
)
|
|
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
|
|
scheduler: SCHEDULER_NAME_VALUES = InputField(
|
|
default="euler",
|
|
description=FieldDescriptions.scheduler,
|
|
ui_type=UIType.Scheduler,
|
|
)
|
|
unet: UNetField = InputField(
|
|
description=FieldDescriptions.unet,
|
|
input=Input.Connection,
|
|
title="UNet",
|
|
ui_order=2,
|
|
)
|
|
control: Optional[Union[ControlField, list[ControlField]]] = InputField(
|
|
default=None,
|
|
input=Input.Connection,
|
|
ui_order=5,
|
|
)
|
|
ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]] = InputField(
|
|
description=FieldDescriptions.ip_adapter,
|
|
title="IP-Adapter",
|
|
default=None,
|
|
input=Input.Connection,
|
|
ui_order=6,
|
|
)
|
|
t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]] = InputField(
|
|
description=FieldDescriptions.t2i_adapter,
|
|
title="T2I-Adapter",
|
|
default=None,
|
|
input=Input.Connection,
|
|
ui_order=7,
|
|
)
|
|
cfg_rescale_multiplier: float = InputField(
|
|
title="CFG Rescale Multiplier", default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier
|
|
)
|
|
latents: Optional[LatentsField] = InputField(
|
|
default=None,
|
|
description=FieldDescriptions.latents,
|
|
input=Input.Connection,
|
|
ui_order=4,
|
|
)
|
|
denoise_mask: Optional[DenoiseMaskField] = InputField(
|
|
default=None,
|
|
description=FieldDescriptions.mask,
|
|
input=Input.Connection,
|
|
ui_order=8,
|
|
)
|
|
|
|
@field_validator("cfg_scale")
@classmethod
def ge_one(cls, v: Union[List[float], float]) -> Union[List[float], float]:
    """Validate that all cfg_scale values are >= 1.

    Args:
        v: A single CFG scale or a list of per-step CFG scales.

    Returns:
        `v` unchanged when every value is at least 1.

    Raises:
        ValueError: If any value is below 1.
    """
    scales = v if isinstance(v, list) else [v]
    if any(scale < 1 for scale in scales):
        # A value of exactly 1 is accepted, so the message must say "or equal".
        raise ValueError("cfg_scale must be greater than or equal to 1")
    return v
|
|
|
|
def _get_text_embeddings_and_masks(
    self,
    cond_list: list[ConditioningField],
    context: InvocationContext,
    device: torch.device,
    dtype: torch.dtype,
) -> tuple[Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]], list[Optional[torch.Tensor]]]:
    """Load the first conditioning and the optional mask tensor for each conditioning field.

    Returns two parallel lists: the conditionings moved to `device`/`dtype`, and the loaded
    mask tensors (None for fields that have no mask).
    """
    embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = []
    region_masks: list[Optional[torch.Tensor]] = []

    for cond_field in cond_list:
        loaded = context.conditioning.load(cond_field.conditioning_name)
        embeddings.append(loaded.conditionings[0].to(device=device, dtype=dtype))

        mask_field = cond_field.mask
        region_mask = None if mask_field is None else context.tensors.load(mask_field.tensor_name)
        region_masks.append(region_mask)

    return embeddings, region_masks
|
|
|
|
def _preprocess_regional_prompt_mask(
|
|
self, mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype
|
|
) -> torch.Tensor:
|
|
"""Preprocess a regional prompt mask to match the target height and width.
|
|
If mask is None, returns a mask of all ones with the target height and width.
|
|
If mask is not None, resizes the mask to the target height and width using 'nearest' interpolation.
|
|
|
|
Returns:
|
|
torch.Tensor: The processed mask. shape: (1, 1, target_height, target_width).
|
|
"""
|
|
|
|
if mask is None:
|
|
return torch.ones((1, 1, target_height, target_width), dtype=dtype)
|
|
|
|
mask = to_standard_float_mask(mask, out_dtype=dtype)
|
|
|
|
tf = torchvision.transforms.Resize(
|
|
(target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
|
|
)
|
|
|
|
# Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w).
|
|
mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w)
|
|
resized_mask = tf(mask)
|
|
return resized_mask
|
|
|
|
def _concat_regional_text_embeddings(
    self,
    text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]],
    masks: Optional[list[Optional[torch.Tensor]]],
    latent_height: int,
    latent_width: int,
    dtype: torch.dtype,
) -> tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[TextConditioningRegions]]:
    """Join the per-prompt embeddings along dim 1 and record which slice belongs to which mask.

    Region information (masks and embedding ranges) is only produced when at least one prompt
    has a mask; otherwise the second element of the returned tuple is None.
    """
    if masks is None:
        masks = [None] * len(text_conditionings)
    assert len(text_conditionings) == len(masks)

    is_sdxl = type(text_conditionings[0]) is SDXLConditioningInfo
    has_regions = not all(m is None for m in masks)

    embed_chunks = []
    region_masks = []
    region_ranges = []
    pooled_embeds = None
    time_ids = None
    offset = 0

    for idx, cond_info in enumerate(text_conditionings):
        region_mask = masks[idx]

        if is_sdxl:
            # Only one pooled_embeds / add_time_ids pair can be kept. The first prompt seeds the
            # values and any later prompt without a mask overrides them, because unmasked prompts
            # are more likely to carry the global prompt information. Ideally there is exactly one
            # unmasked global prompt, but that is not enforced.
            #
            # HACK(ryand): Having to pick a single pair here is a fundamental interface issue: the
            # SDXL Compel nodes were not designed for regional prompting. DenoiseLatents should
            # ideally accept one pooled_embeds tensor plus a list of text embeds with region
            # masks, but that would be a major breaking change to a popular node.
            if pooled_embeds is None or region_mask is None:
                pooled_embeds = cond_info.pooled_embeds
            if time_ids is None or region_mask is None:
                time_ids = cond_info.add_time_ids

        embed_chunks.append(cond_info.embeds)
        chunk_len = cond_info.embeds.shape[1]

        if has_regions:
            region_ranges.append(Range(start=offset, end=offset + chunk_len))
            region_masks.append(
                self._preprocess_regional_prompt_mask(region_mask, latent_height, latent_width, dtype=dtype)
            )

        offset += chunk_len

    joined_embeds = torch.cat(embed_chunks, dim=1)
    assert len(joined_embeds.shape) == 3  # batch_size, seq_len, token_len

    regions = (
        TextConditioningRegions(masks=torch.cat(region_masks, dim=1), ranges=region_ranges)
        if has_regions
        else None
    )

    if not is_sdxl:
        return BasicConditioningInfo(embeds=joined_embeds), regions

    return (
        SDXLConditioningInfo(embeds=joined_embeds, pooled_embeds=pooled_embeds, add_time_ids=time_ids),
        regions,
    )
|
|
|
|
def get_conditioning_data(
    self,
    context: InvocationContext,
    unet: UNet2DConditionModel,
    latent_height: int,
    latent_width: int,
) -> TextConditioningData:
    """Assemble the TextConditioningData for this invocation.

    Positive and negative conditioning inputs are normalized to lists, loaded onto the UNet's
    device/dtype, concatenated (with region masks where present) and bundled together with the
    CFG scale and CFG rescale multiplier.
    """

    def _listify(value):
        # A single conditioning field and a list of them are both accepted as input.
        return value if isinstance(value, list) else [value]

    positive_fields = _listify(self.positive_conditioning)
    negative_fields = _listify(self.negative_conditioning)

    positive_embeds, positive_masks = self._get_text_embeddings_and_masks(
        positive_fields, context, unet.device, unet.dtype
    )
    negative_embeds, negative_masks = self._get_text_embeddings_and_masks(
        negative_fields, context, unet.device, unet.dtype
    )

    cond_text, cond_regions = self._concat_regional_text_embeddings(
        text_conditionings=positive_embeds,
        masks=positive_masks,
        latent_height=latent_height,
        latent_width=latent_width,
        dtype=unet.dtype,
    )
    uncond_text, uncond_regions = self._concat_regional_text_embeddings(
        text_conditionings=negative_embeds,
        masks=negative_masks,
        latent_height=latent_height,
        latent_width=latent_width,
        dtype=unet.dtype,
    )

    # A per-step CFG schedule must provide exactly one value per step.
    if isinstance(self.cfg_scale, list):
        assert (
            len(self.cfg_scale) == self.steps
        ), "cfg_scale (list) must have the same length as the number of steps"

    return TextConditioningData(
        uncond_text=uncond_text,
        cond_text=cond_text,
        uncond_regions=uncond_regions,
        cond_regions=cond_regions,
        guidance_scale=self.cfg_scale,
        guidance_rescale_multiplier=self.cfg_rescale_multiplier,
    )
|
|
|
|
def create_pipeline(
    self,
    unet: UNet2DConditionModel,
    scheduler: Scheduler,
) -> StableDiffusionGeneratorPipeline:
    """Build a StableDiffusionGeneratorPipeline around the given UNet and scheduler.

    No real VAE, text encoder, tokenizer, safety checker or feature extractor is supplied; the
    VAE slot is filled with a stand-in object whose config only exposes `block_out_channels`.
    """

    class FakeVae:
        # Minimal stand-in for a VAE: just a `config` with `block_out_channels == [0]`.
        class FakeVaeConfig:
            def __init__(self) -> None:
                self.block_out_channels = [0]

        def __init__(self) -> None:
            self.config = FakeVae.FakeVaeConfig()

    placeholder_vae = FakeVae()  # TODO (kept from the original): replace this stand-in.
    return StableDiffusionGeneratorPipeline(
        vae=placeholder_vae,
        text_encoder=None,
        tokenizer=None,
        unet=unet,
        scheduler=scheduler,
        safety_checker=None,
        feature_extractor=None,
        requires_safety_checker=False,
    )
|
|
|
|
def prep_control_data(
    self,
    context: InvocationContext,
    control_input: Optional[Union[ControlField, List[ControlField]]],
    latents_shape: List[int],
    exit_stack: ExitStack,
    do_classifier_free_guidance: bool = True,
) -> Optional[List[ControlNetData]]:
    """Load every requested ControlNet and prepare its control image.

    Each control model is entered on `exit_stack`, so it remains loaded until the caller closes
    the stack. Returns None when there is nothing usable in `control_input`.
    """
    # Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR.
    control_height_resize = latents_shape[2] * LATENT_SCALE_FACTOR
    control_width_resize = latents_shape[3] * LATENT_SCALE_FACTOR

    # Normalize the input to a non-empty list[ControlField]; anything else means "no control".
    if isinstance(control_input, ControlField):
        control_fields = [control_input]
    elif isinstance(control_input, list) and len(control_input) > 0 and isinstance(control_input[0], ControlField):
        control_fields = control_input
    else:
        return None

    # FIXME: add checks to skip entry if model or image is None
    #        and if weight is None, populate with default 1.0?
    prepared: list = []
    for field in control_fields:
        model = exit_stack.enter_context(context.models.load(field.control_model))
        source_image = context.images.get_pil(field.image.image_name)

        # FIXME: still need to test with different widths, heights, devices, dtypes
        #        and add in batch_size, num_images_per_prompt?
        #        and do real check for classifier_free_guidance?
        # prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width)
        image_tensor = prepare_control_image(
            image=source_image,
            do_classifier_free_guidance=do_classifier_free_guidance,
            width=control_width_resize,
            height=control_height_resize,
            device=model.device,
            dtype=model.dtype,
            control_mode=field.control_mode,
            resize_mode=field.resize_mode,
        )

        prepared.append(
            ControlNetData(
                model=model,
                image_tensor=image_tensor,
                weight=field.control_weight,
                begin_step_percent=field.begin_step_percent,
                end_step_percent=field.end_step_percent,
                control_mode=field.control_mode,
                # Any resizing currently happens in prepare_control_image(); resize_mode is
                # carried on ControlNetData in case it is needed in the future.
                resize_mode=field.resize_mode,
            )
        )

    # MultiControlNetModel has been refactored out, just need list[ControlNetData]
    return prepared
|
|
|
|
def prep_ip_adapter_image_prompts(
    self,
    context: InvocationContext,
    ip_adapters: List[IPAdapterField],
) -> List[Tuple[torch.Tensor, torch.Tensor]]:
    """Run the IPAdapter CLIPVisionModel, returning image prompt embeddings.

    Returns one `(image_prompt_embeds, uncond_image_prompt_embeds)` pair per IP-Adapter, in the
    same order as `ip_adapters`.
    """
    embeds_per_adapter = []

    for adapter_field in ip_adapters:
        with context.models.load(adapter_field.ip_adapter_model) as adapter_model:
            assert isinstance(adapter_model, IPAdapter)
            encoder_info = context.models.load(adapter_field.image_encoder_model)

            # `adapter_field.image` may be a single ImageField or a list of them.
            image_fields = adapter_field.image if isinstance(adapter_field.image, list) else [adapter_field.image]
            pil_images = [context.images.get_pil(field.image_name) for field in image_fields]

            with encoder_info as encoder_model:
                assert isinstance(encoder_model, CLIPVisionModelWithProjection)
                # Get image embeddings from CLIP and ImageProjModel.
                cond_embeds, uncond_embeds = adapter_model.get_image_embeds(pil_images, encoder_model)
                embeds_per_adapter.append((cond_embeds, uncond_embeds))

    return embeds_per_adapter
|
|
|
|
def prep_ip_adapter_data(
    self,
    context: InvocationContext,
    ip_adapters: List[IPAdapterField],
    image_prompts: List[Tuple[torch.Tensor, torch.Tensor]],
    exit_stack: ExitStack,
    latent_height: int,
    latent_width: int,
    dtype: torch.dtype,
) -> Optional[List[IPAdapterData]]:
    """Load each IP-Adapter model and pair it with its image prompt embeddings and mask.

    `ip_adapters` and `image_prompts` must have the same length. Models are entered on
    `exit_stack`. Returns None when there are no IP-Adapters.
    """
    adapter_data: list = []

    for adapter_field, (cond_embeds, uncond_embeds) in zip(ip_adapters, image_prompts, strict=True):
        adapter_model = exit_stack.enter_context(context.models.load(adapter_field.ip_adapter_model))

        # A missing mask is turned into an all-ones mask by the preprocessing helper.
        raw_mask = None
        if adapter_field.mask is not None:
            raw_mask = context.tensors.load(adapter_field.mask.tensor_name)
        region_mask = self._preprocess_regional_prompt_mask(raw_mask, latent_height, latent_width, dtype=dtype)

        adapter_data.append(
            IPAdapterData(
                ip_adapter_model=adapter_model,
                weight=adapter_field.weight,
                target_blocks=adapter_field.target_blocks,
                begin_step_percent=adapter_field.begin_step_percent,
                end_step_percent=adapter_field.end_step_percent,
                ip_adapter_conditioning=IPAdapterConditioningInfo(cond_embeds, uncond_embeds),
                mask=region_mask,
            )
        )

    if len(adapter_data) == 0:
        return None
    return adapter_data
|
|
|
|
def run_t2i_adapters(
    self,
    context: InvocationContext,
    t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]],
    latents_shape: list[int],
    do_classifier_free_guidance: bool,
) -> Optional[list[T2IAdapterData]]:
    """Run every T2I-Adapter on its input image and collect the resulting adapter states.

    Returns None when no adapter is supplied (None or an empty list).

    Raises:
        ValueError: If an adapter's base model type is neither StableDiffusion1 nor
            StableDiffusionXL.
    """
    if t2i_adapter is None:
        return None

    # Accept either a single T2IAdapterField or a list of them.
    adapter_fields = [t2i_adapter] if isinstance(t2i_adapter, T2IAdapterField) else t2i_adapter
    if len(adapter_fields) == 0:
        return None

    results = []
    for adapter_field in adapter_fields:
        adapter_config = context.models.get_config(adapter_field.t2i_adapter_model.key)
        adapter_loaded = context.models.load(adapter_field.t2i_adapter_model)
        input_image = context.images.get_pil(adapter_field.image.image_name)

        # The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
        base_type = adapter_config.base
        if base_type == BaseModelType.StableDiffusion1:
            max_unet_downscale = 8
        elif base_type == BaseModelType.StableDiffusionXL:
            max_unet_downscale = 4
        else:
            raise ValueError(f"Unexpected T2I-Adapter base model type: '{adapter_config.base}'.")

        adapter_model: T2IAdapter
        with adapter_loaded as adapter_model:
            downscale = adapter_model.total_downscale_factor

            # Pick the input size so that, once the adapter's total_downscale_factor is applied,
            # the result matches the latent dimensions after max_unet_downscale is applied.
            input_height = latents_shape[2] // max_unet_downscale * downscale
            input_width = latents_shape[3] // max_unet_downscale * downscale

            # Note: `do_classifier_free_guidance=False` is hard-coded because only a single image
            # should be prepared here. If CFG is enabled, the resulting tensors are duplicated
            # after the T2I-Adapter model has been applied.
            #
            # Note: ControlNet's `prepare_control_image(...)` is re-used for T2I-Adapter because
            # it has many of the same requirements (e.g. preserving binary masks during resize).
            prepared_image = prepare_control_image(
                image=input_image,
                do_classifier_free_guidance=False,
                width=input_width,
                height=input_height,
                num_channels=adapter_model.config["in_channels"],  # mypy treats this as a FrozenDict
                device=adapter_model.device,
                dtype=adapter_model.dtype,
                resize_mode=adapter_field.resize_mode,
            )

            adapter_state = adapter_model(prepared_image)

        if do_classifier_free_guidance:
            # Duplicate every state along the batch dimension, in place.
            for state_idx, state in enumerate(adapter_state):
                adapter_state[state_idx] = torch.cat([state] * 2, dim=0)

        results.append(
            T2IAdapterData(
                adapter_state=adapter_state,
                weight=adapter_field.weight,
                begin_step_percent=adapter_field.begin_step_percent,
                end_step_percent=adapter_field.end_step_percent,
            )
        )

    return results
|
|
|
|
# original idea by https://github.com/AmericanPresidentJimmyCarter
|
|
# TODO: research more for second order schedulers timesteps
|
|
def init_scheduler(
    self,
    scheduler: Union[Scheduler, ConfigMixin],
    device: torch.device,
    steps: int,
    denoising_start: float,
    denoising_end: float,
    seed: int,
) -> Tuple[int, List[int], int, Dict[str, Any]]:
    """Set up `scheduler` for `steps` steps and trim its timesteps to the requested denoising window.

    Args:
        scheduler: The scheduler to configure. Must be a `ConfigMixin` (asserted).
        device: Device the returned timesteps live on, and the device of the step generator.
        steps: Number of steps passed to `scheduler.set_timesteps(...)`.
        denoising_start: Where the window starts; converted to a timestep threshold of
            `num_train_timesteps * (1 - denoising_start)`.
        denoising_end: Where the window ends; converted to a threshold the same way.
        seed: Base seed for the generator handed to `scheduler.step(...)`, when that method accepts one.

    Returns:
        A tuple `(num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs)`, where
        `timesteps` is the trimmed slice of `scheduler.timesteps`, `init_timestep` is the
        (at most one element) slice holding the first timestep of the window, and
        `scheduler_step_kwargs` holds the extra keyword arguments for `scheduler.step(...)`.
    """
    assert isinstance(scheduler, ConfigMixin)

    # Schedulers flagged as "cpu_only" build their timesteps on the CPU; the result is then moved to `device`.
    run_on_cpu = scheduler.config.get("cpu_only", False)
    scheduler.set_timesteps(steps, device="cpu" if run_on_cpu else device)
    all_timesteps = scheduler.timesteps
    if run_on_cpu:
        all_timesteps = all_timesteps.to(device=device)

    order = scheduler.order
    num_train_timesteps = scheduler.config["num_train_timesteps"]

    def _count_at_or_above(threshold: int, candidates: Any) -> int:
        # Number of timesteps that are >= threshold. NOTE(review): using this count as an index presumably
        # relies on the timesteps being in descending order - confirm.
        return sum(1 for ts in candidates if ts >= threshold)

    # Higher-order schedulers repeat timesteps, so only every `order`-th entry is considered here.
    first_order_timesteps = all_timesteps[::order]

    # Index (in first-order units) of the first timestep inside the window.
    start_threshold = int(round(num_train_timesteps * (1 - denoising_start)))
    start_idx = _count_at_or_above(start_threshold, first_order_timesteps)

    # Length (in first-order units) of the window, counted from the start index.
    end_threshold = int(round(num_train_timesteps * (1 - denoising_end)))
    window_len = _count_at_or_above(end_threshold, first_order_timesteps[start_idx:])

    # Convert both values back to positions in the full (order-expanded) timestep sequence.
    start_idx *= order
    window_len *= order

    init_timestep = all_timesteps[start_idx : start_idx + 1]
    timesteps = all_timesteps[start_idx : start_idx + window_len]
    num_inference_steps = len(timesteps) // order

    scheduler_step_kwargs: Dict[str, Any] = {}
    if "generator" in inspect.signature(scheduler.step).parameters:
        # At some point, someone decided that schedulers that accept a generator should use the original seed with
        # all bits flipped. I don't know the original rationale for this, but now we must keep it like this for
        # reproducibility.
        #
        # These Invoke-supported schedulers accept a generator as of 2024-06-04:
        #   - DDIMScheduler
        #   - DDPMScheduler
        #   - DPMSolverMultistepScheduler
        #   - EulerAncestralDiscreteScheduler
        #   - EulerDiscreteScheduler
        #   - KDPM2AncestralDiscreteScheduler
        #   - LCMScheduler
        #   - TCDScheduler
        scheduler_step_kwargs["generator"] = torch.Generator(device=device).manual_seed(seed ^ 0xFFFFFFFF)
    if isinstance(scheduler, TCDScheduler):
        scheduler_step_kwargs["eta"] = 1.0

    return num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs
|
|
|
|
def prep_inpaint_mask(
    self, context: InvocationContext, latents: torch.Tensor
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], bool]:
    """Load the denoise mask (if any) and prepare it for use with `latents`.

    Args:
        context: Invocation context used to load the stored mask / masked-latents tensors.
        latents: The latents being denoised; the mask is resized to their last two dimensions.

    Returns:
        `(None, None, False)` when no denoise mask is set. Otherwise a tuple of the inverted,
        resized mask (`1 - mask`), the masked latents, and the mask's `gradient` flag.
    """
    denoise_mask = self.denoise_mask
    if denoise_mask is None:
        return None, None, False

    # Bring the stored mask to the spatial size of the latents.
    resized_mask = tv_resize(
        context.tensors.load(denoise_mask.mask_name),
        latents.shape[-2:],
        T.InterpolationMode.BILINEAR,
        antialias=False,
    )

    if denoise_mask.masked_latents_name is None:
        # No stored masked latents: zero the latents wherever the mask is below 0.5.
        masked_latents = torch.where(resized_mask < 0.5, 0.0, latents)
    else:
        masked_latents = context.tensors.load(denoise_mask.masked_latents_name)

    return 1 - resized_mask, masked_latents, denoise_mask.gradient
|
|
|
|
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
    """Denoise the input latents/noise with the configured UNet and save the resulting latents.

    The seed is taken from the noise input if present, otherwise from the latents input, and
    falls back to 0 when neither carries one.

    Args:
        context: Invocation context used to load tensors and models and to save the result.

    Returns:
        A `LatentsOutput` referencing the saved result latents. Its seed is always `None`.

    Raises:
        Exception: If both noise and latents are given but their shapes differ beyond the first
            dimension, or if neither noise nor latents is provided.
    """
    with SilenceWarnings():  # this quenches NSFW nag from diffusers
        seed = None
        noise = None
        if self.noise is not None:
            noise = context.tensors.load(self.noise.latents_name)
            seed = self.noise.seed

        if self.latents is not None:
            latents = context.tensors.load(self.latents.latents_name)
            # The noise's seed (if any) takes precedence over the latents' seed.
            if seed is None:
                seed = self.latents.seed

            # Only the first dimension is allowed to differ between noise and latents.
            if noise is not None and noise.shape[1:] != latents.shape[1:]:
                raise Exception(f"Incompatable 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")

        elif noise is not None:
            # Only noise was supplied: start from all-zero latents shaped like the noise.
            latents = torch.zeros_like(noise)
        else:
            raise Exception("'latents' or 'noise' must be provided!")

        if seed is None:
            seed = 0

        mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)

        # TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
        # below. Investigate whether this is appropriate.
        t2i_adapter_data = self.run_t2i_adapters(
            context,
            self.t2i_adapter,
            latents.shape,
            do_classifier_free_guidance=True,
        )

        ip_adapters: List[IPAdapterField] = []
        if self.ip_adapter is not None:
            # ip_adapter could be a list or a single IPAdapterField. Normalize to a list here.
            if isinstance(self.ip_adapter, list):
                ip_adapters = self.ip_adapter
            else:
                ip_adapters = [self.ip_adapter]

        # If there are IP adapters, the following line runs the adapters' CLIPVision image encoders to return
        # a series of image conditioning embeddings. This is being done here rather than in the
        # big model context below in order to use less VRAM on low-VRAM systems.
        # The image prompts are then passed to prep_ip_adapter_data().
        image_prompts = self.prep_ip_adapter_image_prompts(context=context, ip_adapters=ip_adapters)

        # get the unet's config so that we can pass the base to dispatch_progress()
        unet_config = context.models.get_config(self.unet.unet.key)

        def step_callback(state: PipelineIntermediateState) -> None:
            # Forward each intermediate pipeline state together with the UNet's base model type.
            context.util.sd_step_callback(state, unet_config.base)

        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
            # Lazily load each LoRA and yield it with its weight; the reference to the loaded
            # model info is dropped once the consumer asks for the next item.
            for lora in self.unet.loras:
                lora_info = context.models.load(lora.lora)
                assert isinstance(lora_info.model, LoRAModelRaw)
                yield (lora_info.model, lora.weight)
                del lora_info
            return

        unet_info = context.models.load(self.unet.unet)
        assert isinstance(unet_info.model, UNet2DConditionModel)
        with (
            # `exit_stack` is handed to prep_control_data() / prep_ip_adapter_data() below.
            ExitStack() as exit_stack,
            unet_info as unet,
            ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
            set_seamless(unet, self.unet.seamless_axes),  # FIXME
            # Apply the LoRA after unet has been moved to its target device for faster patching.
            ModelPatcher.apply_lora_unet(unet, _lora_loader()),
        ):
            assert isinstance(unet, UNet2DConditionModel)
            # Move every tensor input onto the UNet's device and dtype.
            latents = latents.to(device=unet.device, dtype=unet.dtype)
            if noise is not None:
                noise = noise.to(device=unet.device, dtype=unet.dtype)
            if mask is not None:
                mask = mask.to(device=unet.device, dtype=unet.dtype)
            if masked_latents is not None:
                masked_latents = masked_latents.to(device=unet.device, dtype=unet.dtype)

            scheduler = get_scheduler(
                context=context,
                scheduler_info=self.unet.scheduler,
                scheduler_name=self.scheduler,
                seed=seed,
            )

            pipeline = self.create_pipeline(unet, scheduler)

            _, _, latent_height, latent_width = latents.shape
            conditioning_data = self.get_conditioning_data(
                context=context, unet=unet, latent_height=latent_height, latent_width=latent_width
            )

            controlnet_data = self.prep_control_data(
                context=context,
                control_input=self.control,
                latents_shape=latents.shape,
                # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
                do_classifier_free_guidance=True,
                exit_stack=exit_stack,
            )

            ip_adapter_data = self.prep_ip_adapter_data(
                context=context,
                ip_adapters=ip_adapters,
                image_prompts=image_prompts,
                exit_stack=exit_stack,
                latent_height=latent_height,
                latent_width=latent_width,
                dtype=unet.dtype,
            )

            # Trim the scheduler's timesteps to the [denoising_start, denoising_end] window.
            num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
                scheduler,
                device=unet.device,
                steps=self.steps,
                denoising_start=self.denoising_start,
                denoising_end=self.denoising_end,
                seed=seed,
            )

            result_latents = pipeline.latents_from_embeddings(
                latents=latents,
                timesteps=timesteps,
                init_timestep=init_timestep,
                noise=noise,
                seed=seed,
                mask=mask,
                masked_latents=masked_latents,
                gradient_mask=gradient_mask,
                num_inference_steps=num_inference_steps,
                scheduler_step_kwargs=scheduler_step_kwargs,
                conditioning_data=conditioning_data,
                control_data=controlnet_data,
                ip_adapter_data=ip_adapter_data,
                t2i_adapter_data=t2i_adapter_data,
                callback=step_callback,
            )

        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
        result_latents = result_latents.to("cpu")
        TorchDevice.empty_cache()

        name = context.tensors.save(tensor=result_latents)
    return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)
|
|
|
|
|
|
@invocation(
    "l2i",
    title="Latents to Image",
    tags=["latents", "image", "vae", "l2i"],
    category="latents",
    version="1.2.2",
)
class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Generates an image from latents."""

    # Latents to decode (connection-only input).
    latents: LatentsField = InputField(description=FieldDescriptions.latents, input=Input.Connection)
    # VAE used for decoding (connection-only input).
    vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)
    # Enables VAE tiling for the decode.
    tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
    # Run the VAE in float32 instead of float16.
    fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Decode the input latents with the VAE and save the first decoded image.

        Args:
            context: Invocation context used to load the latents and VAE and to save the image.

        Returns:
            An `ImageOutput` built from the saved image.
        """
        latents = context.tensors.load(self.latents.latents_name)

        vae_info = context.models.load(self.vae.vae)
        assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL, AutoencoderTiny))

        # Attention processors for which the attention block does not need to be in float32.
        low_precision_attn_processors = (
            AttnProcessor2_0,
            XFormersAttnProcessor,
            LoRAXFormersAttnProcessor,
            LoRAAttnProcessor2_0,
        )

        with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
            assert isinstance(vae, torch.nn.Module)
            latents = latents.to(vae.device)

            if not self.fp32:
                vae.to(dtype=torch.float16)
                latents = latents.half()
            else:
                vae.to(dtype=torch.float32)

                attn_can_stay_low_precision = hasattr(vae.decoder, "mid_block") and isinstance(
                    vae.decoder.mid_block.attentions[0].processor, low_precision_attn_processors
                )
                if attn_can_stay_low_precision:
                    # if xformers or torch_2_0 is used attention block does not need
                    # to be in float32 which can save lots of memory
                    input_dtype = latents.dtype
                    vae.post_quant_conv.to(input_dtype)
                    vae.decoder.conv_in.to(input_dtype)
                    vae.decoder.mid_block.to(input_dtype)
                else:
                    latents = latents.float()

            # Tiling is used when requested on the node or forced through the app config.
            if self.tiled or context.config.get().force_tiled_decode:
                vae.enable_tiling()
            else:
                vae.disable_tiling()

            # clear memory as vae decode can request a lot
            TorchDevice.empty_cache()

            with torch.inference_mode():
                # copied from diffusers pipeline
                scaled_latents = latents / vae.config.scaling_factor
                decoded = vae.decode(scaled_latents, return_dict=False)[0]
                # denormalize
                decoded = (decoded / 2 + 0.5).clamp(0, 1)
                # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
                np_image = decoded.cpu().permute(0, 2, 3, 1).float().numpy()

                # Only the first image of the batch is kept.
                image = VaeImageProcessor.numpy_to_pil(np_image)[0]

        TorchDevice.empty_cache()

        image_dto = context.images.save(image=image)
        return ImageOutput.build(image_dto)
|
|
|
|
|
|
# Interpolation modes selectable on the latent resize/scale nodes; the chosen value is passed
# as `mode` to `torch.nn.functional.interpolate`.
LATENTS_INTERPOLATION_MODE = Literal[
    "nearest",
    "linear",
    "bilinear",
    "bicubic",
    "trilinear",
    "area",
    "nearest-exact",
]
|
|
|
|
|
|
@invocation(
    "lresize",
    title="Resize Latents",
    tags=["latents", "resize"],
    category="latents",
    version="1.0.2",
)
class ResizeLatentsInvocation(BaseInvocation):
    """Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""

    # Latents to resize (connection-only input).
    latents: LatentsField = InputField(
        description=FieldDescriptions.latents,
        input=Input.Connection,
    )
    # Target width; must be >= 64 and a multiple of LATENT_SCALE_FACTOR.
    width: int = InputField(
        ge=64,
        multiple_of=LATENT_SCALE_FACTOR,
        description=FieldDescriptions.width,
    )
    # Target height; must be >= 64 and a multiple of LATENT_SCALE_FACTOR.
    # Fix: this field previously reused the *width* description (copy-paste error).
    height: int = InputField(
        ge=64,
        multiple_of=LATENT_SCALE_FACTOR,
        description=FieldDescriptions.height,
    )
    # Interpolation algorithm handed to `torch.nn.functional.interpolate`.
    mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode)
    # Only honoured for the "bilinear" and "bicubic" modes (see `invoke`).
    antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)

    def invoke(self, context: InvocationContext) -> LatentsOutput:
        """Interpolate the input latents to the requested size and save the result.

        Args:
            context: Invocation context used to load the input latents and save the output.

        Returns:
            A `LatentsOutput` referencing the resized latents, carrying over the input's seed.
        """
        latents = context.tensors.load(self.latents.latents_name)
        device = TorchDevice.choose_torch_device()

        resized_latents = torch.nn.functional.interpolate(
            latents.to(device),
            # The fields are floor-divided by LATENT_SCALE_FACTOR to get the latent-space size.
            size=(self.height // LATENT_SCALE_FACTOR, self.width // LATENT_SCALE_FACTOR),
            mode=self.mode,
            # The antialias flag is only forwarded for the bilinear/bicubic modes.
            antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
        )

        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
        resized_latents = resized_latents.to("cpu")

        TorchDevice.empty_cache()

        name = context.tensors.save(tensor=resized_latents)
        return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
|
|
|
|
|
@invocation(
    "lscale",
    title="Scale Latents",
    tags=["latents", "resize"],
    category="latents",
    version="1.0.2",
)
class ScaleLatentsInvocation(BaseInvocation):
    """Scales latents by a given factor."""

    # Latents to scale (connection-only input).
    latents: LatentsField = InputField(description=FieldDescriptions.latents, input=Input.Connection)
    # Multiplier handed to `interpolate` as `scale_factor`; must be > 0.
    scale_factor: float = InputField(gt=0, description=FieldDescriptions.scale_factor)
    # Interpolation algorithm handed to `torch.nn.functional.interpolate`.
    mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode)
    # Only honoured for the "bilinear" and "bicubic" modes (see `invoke`).
    antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)

    def invoke(self, context: InvocationContext) -> LatentsOutput:
        """Interpolate the input latents by `scale_factor` and save the result.

        Args:
            context: Invocation context used to load the input latents and save the output.

        Returns:
            A `LatentsOutput` referencing the scaled latents, carrying over the input's seed.
        """
        source_latents = context.tensors.load(self.latents.latents_name)
        target_device = TorchDevice.choose_torch_device()

        # The antialias flag is only forwarded for the bilinear/bicubic modes.
        if self.mode in ["bilinear", "bicubic"]:
            use_antialias = self.antialias
        else:
            use_antialias = False

        scaled_latents = torch.nn.functional.interpolate(
            source_latents.to(target_device),
            scale_factor=self.scale_factor,
            mode=self.mode,
            antialias=use_antialias,
        )

        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
        scaled_latents = scaled_latents.to("cpu")
        TorchDevice.empty_cache()

        saved_name = context.tensors.save(tensor=scaled_latents)
        return LatentsOutput.build(latents_name=saved_name, latents=scaled_latents, seed=self.latents.seed)
|
|
|
|
|
|
@invocation(
    "i2l",
    title="Image to Latents",
    tags=["latents", "image", "vae", "i2l"],
    category="latents",
    version="1.0.2",
)
class ImageToLatentsInvocation(BaseInvocation):
    """Encodes an image into latents."""

    # Image to encode.
    image: ImageField = InputField(description="The image to encode")
    # VAE used for encoding (connection-only input).
    vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)
    # Enables VAE tiling for the encode.
    tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
    # Run the VAE in float32 instead of float16.
    fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)

    @staticmethod
    def vae_encode(vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor) -> torch.Tensor:
        """Encode `image_tensor` into scaled latents using the given VAE.

        Args:
            vae_info: Loaded model handle; entered as a context manager to obtain the VAE.
            upcast: If true the VAE is moved to float32, otherwise to float16.
            tiled: Whether to enable (true) or disable (false) VAE tiling.
            image_tensor: Image tensor to encode; moved to the VAE's device and dtype first.

        Returns:
            The encoded latents, multiplied by the VAE's `scaling_factor` and cast back to the
            dtype the VAE had on entry.
        """
        with vae_info as vae:
            assert isinstance(vae, torch.nn.Module)
            # Remember the VAE's dtype on entry; the result is cast back to it at the end.
            entry_dtype = vae.dtype

            if not upcast:
                vae.to(dtype=torch.float16)
            else:
                vae.to(dtype=torch.float32)

                low_precision_attn_processors = (
                    AttnProcessor2_0,
                    XFormersAttnProcessor,
                    LoRAXFormersAttnProcessor,
                    LoRAAttnProcessor2_0,
                )
                attn_can_stay_low_precision = hasattr(vae.decoder, "mid_block") and isinstance(
                    vae.decoder.mid_block.attentions[0].processor, low_precision_attn_processors
                )
                if attn_can_stay_low_precision:
                    # if xformers or torch_2_0 is used attention block does not need
                    # to be in float32 which can save lots of memory
                    vae.post_quant_conv.to(entry_dtype)
                    vae.decoder.conv_in.to(entry_dtype)
                    vae.decoder.mid_block.to(entry_dtype)

            if tiled:
                vae.enable_tiling()
            else:
                vae.disable_tiling()

            # non_noised_latents_from_image
            vae_input = image_tensor.to(device=vae.device, dtype=vae.dtype)
            with torch.inference_mode():
                encoded = ImageToLatentsInvocation._encode_to_tensor(vae, vae_input)

            encoded = vae.config.scaling_factor * encoded
            encoded = encoded.to(dtype=entry_dtype)

        return encoded

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        """Load the input image, encode it with the VAE and save the resulting latents.

        Args:
            context: Invocation context used to load the image and VAE and to save the latents.

        Returns:
            A `LatentsOutput` referencing the saved latents. Its seed is always `None`.
        """
        pil_image = context.images.get_pil(self.image.image_name)
        vae_info = context.models.load(self.vae.vae)

        image_tensor = image_resized_to_grid_as_tensor(pil_image.convert("RGB"))
        # A 3-dimensional tensor gets a leading dimension of size 1 added.
        if image_tensor.dim() == 3:
            image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")

        latents = self.vae_encode(vae_info, self.fp32, self.tiled, image_tensor).to("cpu")

        saved_name = context.tensors.save(tensor=latents)
        return LatentsOutput.build(latents_name=saved_name, latents=latents, seed=None)

    @singledispatchmethod
    @staticmethod
    def _encode_to_tensor(vae: AutoencoderKL, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
        """Default encoder path: sample latents from the VAE's `latent_dist` (dispatches on the type of `vae`)."""
        assert isinstance(vae, torch.nn.Module)
        latent_dist = vae.encode(image_tensor).latent_dist
        # FIXME: uses torch.randn. make reproducible!
        sampled: torch.Tensor = latent_dist.sample().to(dtype=vae.dtype)
        return sampled

    @_encode_to_tensor.register
    @staticmethod
    def _(vae: AutoencoderTiny, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
        """`AutoencoderTiny` encoder path: take the `latents` attribute of the encode output directly."""
        assert isinstance(vae, torch.nn.Module)
        encoded: torch.FloatTensor = vae.encode(image_tensor).latents
        return encoded
|