mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
First draft of area-based latent-space regional prompting.
This commit is contained in:
parent
32f602ab2a
commit
bc869cad62
@ -9,46 +9,75 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
WithMetadata,
|
||||
invocation,
|
||||
)
|
||||
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField, ImageOutput
|
||||
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, DenoisingArea, ImageField, ImageOutput
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||
|
||||
|
||||
# @invocation(
|
||||
# "add_conditioning_mask",
|
||||
# title="Add Conditioning Mask",
|
||||
# tags=["conditioning"],
|
||||
# category="conditioning",
|
||||
# version="1.0.0",
|
||||
# )
|
||||
# class AddConditioningMaskInvocation(BaseInvocation):
|
||||
# """Add a mask to an existing conditioning tensor."""
|
||||
|
||||
# conditioning: ConditioningField = InputField(description="The conditioning tensor to add a mask to.")
|
||||
# mask: ImageField = InputField(
|
||||
# description="A mask image to add to the conditioning tensor. Only the first channel of the image is used. "
|
||||
# "Pixels <128 are excluded from the mask, pixels >=128 are included in the mask."
|
||||
# )
|
||||
# mask_strength: float = InputField(
|
||||
# description="The strength of the mask to apply to the conditioning tensor.", default=1.0
|
||||
# )
|
||||
|
||||
# @staticmethod
|
||||
# def convert_image_to_mask(image: Image.Image) -> torch.Tensor:
|
||||
# """Convert a PIL image to a uint8 mask tensor."""
|
||||
# np_image = np.array(image)
|
||||
# torch_image = torch.from_numpy(np_image[:, :, 0])
|
||||
# mask = torch_image >= 128
|
||||
# return mask.to(dtype=torch.uint8)
|
||||
|
||||
# def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
# image = context.services.images.get_pil_image(self.mask.image_name)
|
||||
# mask = self.convert_image_to_mask(image)
|
||||
|
||||
# mask_name = f"{context.graph_execution_state_id}__{self.id}_conditioning_mask"
|
||||
# context.services.latents.save(mask_name, mask)
|
||||
|
||||
# self.conditioning.mask_name = mask_name
|
||||
# self.conditioning.mask_strength = self.mask_strength
|
||||
# return ConditioningOutput(conditioning=self.conditioning)
|
||||
|
||||
@invocation(
    "add_conditioning_area",
    title="Add Conditioning Area",
    tags=["conditioning"],
    category="conditioning",
    version="1.0.0",
)
class AddConditioningAreaInvocation(BaseInvocation):
    """Add a denoising area to an existing conditioning tensor.

    The node copies its rectangle inputs (`height`, `width`, `y_top`, `x_left`) into a `DenoisingArea`, stores that
    area together with `mask_strength` on the incoming `ConditioningField`, and returns the same field wrapped in a
    `ConditioningOutput`.
    """

    conditioning: ConditioningField = InputField(description="The conditioning tensor to add a denoising area to.")
    # mask: ImageField = InputField(
    #     description="A mask image to add to the conditioning tensor. Only the first channel of the image is used. "
    #     "Pixels <128 are excluded from the mask, pixels >=128 are included in the mask."
    # )
    mask_strength: float = InputField(
        description="The strength of the mask to apply to the conditioning tensor.", default=1.0
    )
    # NOTE(review): these descriptions say "latent space", but the denoising code in this commit floor-divides the
    # area by 8 before slicing the latents, which suggests the values are really image-space pixels. Confirm the
    # intended coordinate space and align the descriptions (or the denoiser) accordingly.
    height: int = InputField(description="The height of the area in latent space.")
    width: int = InputField(description="The width of the area in latent space.")
    y_top: int = InputField(description="The top of the area in latent space.")
    x_left: int = InputField(description="The left of the area in latent space.")

    def invoke(self, context: InvocationContext) -> ConditioningOutput:
        """Record the strength and area on the input conditioning field and return it.

        Note that the input `ConditioningField` instance is modified in place rather than copied.
        """
        self.conditioning.mask_strength = self.mask_strength
        # The node's inputs are named `y_top` / `x_left`; `DenoisingArea` names the same edges `top_y` / `left_x`.
        self.conditioning.denoising_area = DenoisingArea(
            height=self.height,
            width=self.width,
            top_y=self.y_top,
            left_x=self.x_left,
        )
        return ConditioningOutput(conditioning=self.conditioning)
|
||||
|
||||
|
||||
|
@ -43,7 +43,7 @@ from invokeai.backend.model_management.models import ModelType, SilenceWarnings
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
ConditioningData,
|
||||
IPAdapterConditioningInfo,
|
||||
TextConditioningInfoWithMask,
|
||||
TextConditioningInfoWithArea,
|
||||
)
|
||||
|
||||
from ...backend.model_management.lora import ModelPatcher
|
||||
@ -339,17 +339,13 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
if not isinstance(positive_conditioning_list, list):
|
||||
positive_conditioning_list = [positive_conditioning_list]
|
||||
|
||||
text_embeddings: list[TextConditioningInfoWithMask] = []
|
||||
text_embeddings: list[TextConditioningInfoWithArea] = []
|
||||
for positive_conditioning in positive_conditioning_list:
|
||||
positive_cond_data = context.services.latents.get(positive_conditioning.conditioning_name)
|
||||
mask_name = positive_conditioning.mask_name
|
||||
mask = None
|
||||
if mask_name is not None:
|
||||
mask = context.services.latents.get(mask_name)
|
||||
text_embeddings.append(
|
||||
TextConditioningInfoWithMask(
|
||||
TextConditioningInfoWithArea(
|
||||
text_conditioning_info=positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype),
|
||||
mask=mask,
|
||||
area=positive_conditioning.denoising_area,
|
||||
mask_strength=positive_conditioning.mask_strength,
|
||||
)
|
||||
)
|
||||
|
@ -423,16 +423,23 @@ class ColorInvocation(BaseInvocation):
|
||||
|
||||
# region Conditioning
|
||||
|
||||
class DenoisingArea(BaseModel):
    """An axis-aligned rectangle describing the region that a conditioning is applied to.

    The rectangle is given by its top-left corner (`top_y`, `left_x`) and its extent (`height`, `width`). When it is
    used to index a tensor, rows `top_y : top_y + height` and columns `left_x : left_x + width` are selected. The
    model itself does not record which coordinate space the values are expressed in; that is defined by whoever
    holds the area.
    """

    top_y: int = Field(description="Y coordinate of the top edge of the area.")
    left_x: int = Field(description="X coordinate of the left edge of the area.")
    height: int = Field(description="Height of the area, i.e. its extent along the Y axis.")
    width: int = Field(description="Width of the area, i.e. its extent along the X axis.")
|
||||
|
||||
|
||||
class ConditioningField(BaseModel):
|
||||
"""A conditioning tensor primitive value"""
|
||||
|
||||
conditioning_name: str = Field(description="The name of conditioning tensor")
|
||||
mask_name: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The mask associated with this conditioning tensor. Excluded regions should be set to 0, included "
|
||||
"regions should be set to 1.",
|
||||
)
|
||||
# mask_name: Optional[str] = Field(
|
||||
# default=None,
|
||||
# description="The mask associated with this conditioning tensor. Excluded regions should be set to 0, included "
|
||||
# "regions should be set to 1.",
|
||||
# )
|
||||
denoising_area: DenoisingArea = Field(description="The area to apply this conditioning to. Coordinates are in latent space.")
|
||||
mask_strength: float = Field(
|
||||
default=1.0,
|
||||
description="The strength of the mask. Only has an effect if mask_name is set. The strength is relative to "
|
||||
|
@ -5,6 +5,8 @@ from typing import Any, List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.app.invocations.primitives import DenoisingArea
|
||||
|
||||
from .cross_attention_control import Arguments
|
||||
|
||||
|
||||
@ -39,15 +41,15 @@ class SDXLConditioningInfo(BasicConditioningInfo):
|
||||
return super().to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
class TextConditioningInfoWithArea:
    """A text conditioning paired with the rectangular area it applies to and a relative strength.

    Attributes:
        text_conditioning_info: The text conditioning (basic or SDXL) to apply.
        area: The rectangle that this conditioning is restricted to.
        mask_strength: Weight given to this conditioning inside its area when several conditionings are combined.
    """

    def __init__(
        self,
        text_conditioning_info: Union[BasicConditioningInfo, SDXLConditioningInfo],
        area: DenoisingArea,
        mask_strength: float,
    ):
        self.text_conditioning_info = text_conditioning_info
        self.area = area
        self.mask_strength = mask_strength
|
||||
|
||||
|
||||
@ -74,7 +76,7 @@ class IPAdapterConditioningInfo:
|
||||
@dataclass
|
||||
class ConditioningData:
|
||||
unconditioned_embeddings: Union[BasicConditioningInfo, SDXLConditioningInfo]
|
||||
text_embeddings: list[TextConditioningInfoWithMask]
|
||||
text_embeddings: list[TextConditioningInfoWithArea]
|
||||
"""
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
|
||||
|
@ -8,6 +8,7 @@ import torch
|
||||
import torchvision
|
||||
from diffusers import UNet2DConditionModel
|
||||
from typing_extensions import TypeAlias
|
||||
from invokeai.app.invocations.primitives import DenoisingArea
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
@ -220,9 +221,21 @@ class InvokeAIDiffuserComponent:
|
||||
)
|
||||
wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
|
||||
|
||||
cond_next_xs = []
|
||||
uncond_next_x = None
|
||||
for text_conditioning in conditioning_data.text_embeddings:
|
||||
latent_shape = sample.shape
|
||||
cond_next_xs = torch.zeros(
|
||||
(latent_shape[0], len(conditioning_data.text_embeddings)) + latent_shape[1:],
|
||||
dtype=sample.dtype,
|
||||
device=sample.device,
|
||||
)
|
||||
uncond_next_xs = torch.zeros(
|
||||
(latent_shape[0], len(conditioning_data.text_embeddings)) + latent_shape[1:],
|
||||
dtype=sample.dtype,
|
||||
device=sample.device,
|
||||
)
|
||||
# Initialize counts to 1e-10 to avoid divide-by-zero.
|
||||
cond_count = torch.ones_like(sample) * 1e-10
|
||||
|
||||
for i, text_conditioning in enumerate(conditioning_data.text_embeddings):
|
||||
if wants_cross_attention_control or self.sequential_guidance:
|
||||
raise NotImplementedError(
|
||||
"Sequential conditioning has not yet been updated to work with multiple text embeddings."
|
||||
@ -242,11 +255,26 @@ class InvokeAIDiffuserComponent:
|
||||
# down_intrablock_additional_residuals=down_intrablock_additional_residuals,
|
||||
# )
|
||||
else:
|
||||
area = text_conditioning.area
|
||||
# TODO(ryand): Use LATENT_SCALE_FACTOR instead of hard-coding to 8.
|
||||
latent_area = DenoisingArea(
|
||||
height=max(1, area.height // 8),
|
||||
width=max(1, area.width // 8),
|
||||
top_y=area.top_y // 8,
|
||||
left_x=area.left_x // 8,
|
||||
)
|
||||
area_sample = sample[
|
||||
:,
|
||||
:,
|
||||
latent_area.top_y : latent_area.top_y + latent_area.height,
|
||||
latent_area.left_x : latent_area.left_x + latent_area.width,
|
||||
]
|
||||
|
||||
(
|
||||
unconditioned_next_x,
|
||||
conditioned_next_x,
|
||||
) = self._apply_standard_conditioning(
|
||||
x=sample,
|
||||
x=area_sample,
|
||||
sigma=timestep,
|
||||
cond_text_embedding=text_conditioning.text_conditioning_info,
|
||||
uncond_text_embedding=conditioning_data.unconditioned_embeddings,
|
||||
@ -255,52 +283,32 @@ class InvokeAIDiffuserComponent:
|
||||
mid_block_additional_residual=mid_block_additional_residual,
|
||||
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
|
||||
)
|
||||
cond_next_xs.append(conditioned_next_x)
|
||||
# HACK(ryand): We re-run unconditioned denoising for each text embedding, but we should only need to do it
|
||||
# once.
|
||||
uncond_next_x = unconditioned_next_x
|
||||
|
||||
# TODO(ryand): Think about how to handle the batch dimension here. Should this be torch.stack()? It probably
|
||||
# doesn't matter, as I'm sure there are many other places where we don't properly support batching.
|
||||
cond_out = torch.concat(cond_next_xs, dim=0)
|
||||
# Initialize count to 1e-9 to avoid division by zero.
|
||||
cond_count = torch.ones_like(cond_out[0, ...]) * 1e-9
|
||||
# TODO(ryand): Apply mask here.
|
||||
cond_next_xs[
|
||||
:,
|
||||
i,
|
||||
latent_area.top_y : latent_area.top_y + latent_area.height,
|
||||
latent_area.left_x : latent_area.left_x + latent_area.width,
|
||||
] = conditioned_next_x * text_conditioning.mask_strength
|
||||
uncond_next_xs[
|
||||
:,
|
||||
i,
|
||||
latent_area.top_y : latent_area.top_y + latent_area.height,
|
||||
latent_area.left_x : latent_area.left_x + latent_area.width,
|
||||
] = unconditioned_next_x * text_conditioning.mask_strength
|
||||
# `cond_count` is created with `torch.ones_like(sample)`, so it is 4-D like `sample`: both leading axes must be
# kept whole before the area bounds. With a single leading `:` the bounds would slice axes 1 and 2 instead of the
# last two (spatial) axes, so the per-pixel weights would be accumulated in the wrong place.
cond_count[
    :,
    :,
    latent_area.top_y : latent_area.top_y + latent_area.height,
    latent_area.left_x : latent_area.left_x + latent_area.width,
] += text_conditioning.mask_strength
|
||||
|
||||
_, _, height, width = cond_out.shape
|
||||
for te_idx, te in enumerate(conditioning_data.text_embeddings):
|
||||
mask = te.mask
|
||||
if mask is not None:
|
||||
# Resize if necessary.
|
||||
tf = torchvision.transforms.Resize(
|
||||
(height, width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
|
||||
)
|
||||
mask = mask.unsqueeze(0).unsqueeze(0) # Shape: (h, w) -> (1, 1, h, w)
|
||||
mask = tf(mask)
|
||||
# TODO(ryand): Do other apps apply the same mask weight and count to the unconditioned output?
|
||||
|
||||
# TODO(ryand): We are converting from uint8 to float here. Should we just be storing a float mask to
|
||||
# begin with?
|
||||
mask = mask.to(cond_out.device, cond_out.dtype)
|
||||
cond_next_x = cond_next_xs.sum(dim=1) / cond_count
|
||||
uncond_next_x = uncond_next_xs.sum(dim=1) / cond_count
|
||||
|
||||
# Make sure that all mask values are either 0.0 or 1.0.
|
||||
# HACK(ryand): This is not the right place to be doing this. Just be clear about the expected format of
|
||||
# the mask in the passed data structures.
|
||||
mask[mask < 0.5] = 0.0
|
||||
mask[mask >= 0.5] = 1.0
|
||||
|
||||
mask *= te.mask_strength
|
||||
else:
|
||||
# mask is None, so treat as a mask of all 1.0s (by taking advantage of torch's treatment of scalar
|
||||
# values).
|
||||
mask = 1.0
|
||||
|
||||
# Apply the mask and update the count.
|
||||
cond_out[te_idx, ...] *= mask[0]
|
||||
cond_count += mask[0]
|
||||
|
||||
# Combine the masked conditionings.
|
||||
cond_out = cond_out.sum(dim=0, keepdim=True) / cond_count
|
||||
|
||||
return uncond_next_x, cond_out
|
||||
return uncond_next_x, cond_next_x
|
||||
|
||||
def do_latent_postprocessing(
|
||||
self,
|
||||
|
Loading…
Reference in New Issue
Block a user