From ef8177ba4b92f5b6614c9998802beb50e2ce2572 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Sat, 24 Feb 2024 18:28:27 -0500 Subject: [PATCH] Get area-based regional prompting working. --- invokeai/app/invocations/primitives.py | 6 +++- .../diffusion/conditioning_data.py | 2 +- .../diffusion/shared_invokeai_diffusion.py | 28 +++++++++++++------ 3 files changed, 25 insertions(+), 11 deletions(-) diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index 8c30417eff..9254befd62 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -423,6 +423,7 @@ class ColorInvocation(BaseInvocation): # region Conditioning + class DenoisingArea(BaseModel): top_y: int = Field(description="TODO") left_x: int = Field(description="TODO") @@ -439,7 +440,10 @@ class ConditioningField(BaseModel): # description="The mask associated with this conditioning tensor. Excluded regions should be set to 0, included " # "regions should be set to 1.", # ) - denoising_area: DenoisingArea = Field(description="The area to apply this conditioning to. Coordinates are in latent space.") + denoising_area: Optional[DenoisingArea] = Field( + description="The area to apply this conditioning to. Coordinates are in latent space.", + default=None, + ) mask_strength: float = Field( default=1.0, description="The strength of the mask. Only has an effect if mask_name is set. The strength is relative to " diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py index 892de0ad50..00cb5375d3 100644 --- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py @@ -45,7 +45,7 @@ class TextConditioningInfoWithArea: def __init__( self, text_conditioning_info: Union[BasicConditioningInfo, SDXLConditioningInfo], - area: DenoisingArea, + area: Optional[DenoisingArea], mask_strength: float, ): self.text_conditioning_info = text_conditioning_info diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py index e62bae75be..4bc9db29fe 100644 --- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py +++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py @@ -5,11 +5,10 @@ from contextlib import contextmanager from typing import Any, Callable, Optional, Union import torch -import torchvision from diffusers import UNet2DConditionModel from typing_extensions import TypeAlias -from invokeai.app.invocations.primitives import DenoisingArea +from invokeai.app.invocations.primitives import DenoisingArea from invokeai.app.services.config import InvokeAIAppConfig from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( BasicConditioningInfo, @@ -256,13 +255,21 @@ class InvokeAIDiffuserComponent: # ) else: area = text_conditioning.area - # TODO(ryand): Use LATENT_SCALE_FACTOR instead of hard-coding to 8. - latent_area = DenoisingArea( - height=max(1, area.height // 8), - width=max(1, area.width // 8), - top_y=area.top_y // 8, - left_x=area.left_x // 8, - ) + if area is None: + latent_area = DenoisingArea( + height=latent_shape[2], + width=latent_shape[3], + top_y=0, + left_x=0, + ) + else: + # TODO(ryand): Use LATENT_SCALE_FACTOR instead of hard-coding to 8. + latent_area = DenoisingArea( + height=max(1, area.height // 8), + width=max(1, area.width // 8), + top_y=area.top_y // 8, + left_x=area.left_x // 8, + ) area_sample = sample[ :, :, @@ -288,16 +295,19 @@ class InvokeAIDiffuserComponent: cond_next_xs[ :, i, + :, latent_area.top_y : latent_area.top_y + latent_area.height, latent_area.left_x : latent_area.left_x + latent_area.width, ] = conditioned_next_x * text_conditioning.mask_strength uncond_next_xs[ :, i, + :, latent_area.top_y : latent_area.top_y + latent_area.height, latent_area.left_x : latent_area.left_x + latent_area.width, ] = unconditioned_next_x * text_conditioning.mask_strength cond_count[ + :, :, latent_area.top_y : latent_area.top_y + latent_area.height, latent_area.left_x : latent_area.left_x + latent_area.width,