Get area-based regional prompting working.

This commit is contained in:
Ryan Dick 2024-02-24 18:28:27 -05:00
parent bc869cad62
commit ef8177ba4b
3 changed files with 25 additions and 11 deletions

View File

@ -423,6 +423,7 @@ class ColorInvocation(BaseInvocation):
# region Conditioning # region Conditioning
class DenoisingArea(BaseModel): class DenoisingArea(BaseModel):
top_y: int = Field(description="TODO") top_y: int = Field(description="TODO")
left_x: int = Field(description="TODO") left_x: int = Field(description="TODO")
@ -439,7 +440,10 @@ class ConditioningField(BaseModel):
# description="The mask associated with this conditioning tensor. Excluded regions should be set to 0, included " # description="The mask associated with this conditioning tensor. Excluded regions should be set to 0, included "
# "regions should be set to 1.", # "regions should be set to 1.",
# ) # )
denoising_area: DenoisingArea = Field(description="The area to apply this conditioning to. Coordinates are in latent space.") denoising_area: Optional[DenoisingArea] = Field(
description="The area to apply this conditioning to. Coordinates are in latent space.",
default=None,
)
mask_strength: float = Field( mask_strength: float = Field(
default=1.0, default=1.0,
description="The strength of the mask. Only has an effect if mask_name is set. The strength is relative to " description="The strength of the mask. Only has an effect if mask_name is set. The strength is relative to "

View File

@ -45,7 +45,7 @@ class TextConditioningInfoWithArea:
def __init__( def __init__(
self, self,
text_conditioning_info: Union[BasicConditioningInfo, SDXLConditioningInfo], text_conditioning_info: Union[BasicConditioningInfo, SDXLConditioningInfo],
area: DenoisingArea, area: Optional[DenoisingArea],
mask_strength: float, mask_strength: float,
): ):
self.text_conditioning_info = text_conditioning_info self.text_conditioning_info = text_conditioning_info

View File

@ -5,11 +5,10 @@ from contextlib import contextmanager
from typing import Any, Callable, Optional, Union from typing import Any, Callable, Optional, Union
import torch import torch
import torchvision
from diffusers import UNet2DConditionModel from diffusers import UNet2DConditionModel
from typing_extensions import TypeAlias from typing_extensions import TypeAlias
from invokeai.app.invocations.primitives import DenoisingArea
from invokeai.app.invocations.primitives import DenoisingArea
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo, BasicConditioningInfo,
@ -256,13 +255,21 @@ class InvokeAIDiffuserComponent:
# ) # )
else: else:
area = text_conditioning.area area = text_conditioning.area
# TODO(ryand): Use LATENT_SCALE_FACTOR instead of hard-coding to 8. if area is None:
latent_area = DenoisingArea( latent_area = DenoisingArea(
height=max(1, area.height // 8), height=latent_shape[2],
width=max(1, area.width // 8), width=latent_shape[3],
top_y=area.top_y // 8, top_y=0,
left_x=area.left_x // 8, left_x=0,
) )
else:
# TODO(ryand): Use LATENT_SCALE_FACTOR instead of hard-coding to 8.
latent_area = DenoisingArea(
height=max(1, area.height // 8),
width=max(1, area.width // 8),
top_y=area.top_y // 8,
left_x=area.left_x // 8,
)
area_sample = sample[ area_sample = sample[
:, :,
:, :,
@ -288,16 +295,19 @@ class InvokeAIDiffuserComponent:
cond_next_xs[ cond_next_xs[
:, :,
i, i,
:,
latent_area.top_y : latent_area.top_y + latent_area.height, latent_area.top_y : latent_area.top_y + latent_area.height,
latent_area.left_x : latent_area.left_x + latent_area.width, latent_area.left_x : latent_area.left_x + latent_area.width,
] = conditioned_next_x * text_conditioning.mask_strength ] = conditioned_next_x * text_conditioning.mask_strength
uncond_next_xs[ uncond_next_xs[
:, :,
i, i,
:,
latent_area.top_y : latent_area.top_y + latent_area.height, latent_area.top_y : latent_area.top_y + latent_area.height,
latent_area.left_x : latent_area.left_x + latent_area.width, latent_area.left_x : latent_area.left_x + latent_area.width,
] = unconditioned_next_x * text_conditioning.mask_strength ] = unconditioned_next_x * text_conditioning.mask_strength
cond_count[ cond_count[
:,
:, :,
latent_area.top_y : latent_area.top_y + latent_area.height, latent_area.top_y : latent_area.top_y + latent_area.height,
latent_area.left_x : latent_area.left_x + latent_area.width, latent_area.left_x : latent_area.left_x + latent_area.width,