mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Get area-based regional prompting working.
This commit is contained in:
parent
bc869cad62
commit
ef8177ba4b
@ -423,6 +423,7 @@ class ColorInvocation(BaseInvocation):
|
||||
|
||||
# region Conditioning
|
||||
|
||||
|
||||
class DenoisingArea(BaseModel):
|
||||
top_y: int = Field(description="TODO")
|
||||
left_x: int = Field(description="TODO")
|
||||
@ -439,7 +440,10 @@ class ConditioningField(BaseModel):
|
||||
# description="The mask associated with this conditioning tensor. Excluded regions should be set to 0, included "
|
||||
# "regions should be set to 1.",
|
||||
# )
|
||||
denoising_area: DenoisingArea = Field(description="The area to apply this conditioning to. Coordinates are in latent space.")
|
||||
denoising_area: Optional[DenoisingArea] = Field(
|
||||
description="The area to apply this conditioning to. Coordinates are in latent space.",
|
||||
default=None,
|
||||
)
|
||||
mask_strength: float = Field(
|
||||
default=1.0,
|
||||
description="The strength of the mask. Only has an effect if mask_name is set. The strength is relative to "
|
||||
|
@ -45,7 +45,7 @@ class TextConditioningInfoWithArea:
|
||||
def __init__(
|
||||
self,
|
||||
text_conditioning_info: Union[BasicConditioningInfo, SDXLConditioningInfo],
|
||||
area: DenoisingArea,
|
||||
area: Optional[DenoisingArea],
|
||||
mask_strength: float,
|
||||
):
|
||||
self.text_conditioning_info = text_conditioning_info
|
||||
|
@ -5,11 +5,10 @@ from contextlib import contextmanager
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
from diffusers import UNet2DConditionModel
|
||||
from typing_extensions import TypeAlias
|
||||
from invokeai.app.invocations.primitives import DenoisingArea
|
||||
|
||||
from invokeai.app.invocations.primitives import DenoisingArea
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
BasicConditioningInfo,
|
||||
@ -256,13 +255,21 @@ class InvokeAIDiffuserComponent:
|
||||
# )
|
||||
else:
|
||||
area = text_conditioning.area
|
||||
# TODO(ryand): Use LATENT_SCALE_FACTOR instead of hard-coding to 8.
|
||||
latent_area = DenoisingArea(
|
||||
height=max(1, area.height // 8),
|
||||
width=max(1, area.width // 8),
|
||||
top_y=area.top_y // 8,
|
||||
left_x=area.left_x // 8,
|
||||
)
|
||||
if area is None:
|
||||
latent_area = DenoisingArea(
|
||||
height=latent_shape[2],
|
||||
width=latent_shape[3],
|
||||
top_y=0,
|
||||
left_x=0,
|
||||
)
|
||||
else:
|
||||
# TODO(ryand): Use LATENT_SCALE_FACTOR instead of hard-coding to 8.
|
||||
latent_area = DenoisingArea(
|
||||
height=max(1, area.height // 8),
|
||||
width=max(1, area.width // 8),
|
||||
top_y=area.top_y // 8,
|
||||
left_x=area.left_x // 8,
|
||||
)
|
||||
area_sample = sample[
|
||||
:,
|
||||
:,
|
||||
@ -288,16 +295,19 @@ class InvokeAIDiffuserComponent:
|
||||
cond_next_xs[
|
||||
:,
|
||||
i,
|
||||
:,
|
||||
latent_area.top_y : latent_area.top_y + latent_area.height,
|
||||
latent_area.left_x : latent_area.left_x + latent_area.width,
|
||||
] = conditioned_next_x * text_conditioning.mask_strength
|
||||
uncond_next_xs[
|
||||
:,
|
||||
i,
|
||||
:,
|
||||
latent_area.top_y : latent_area.top_y + latent_area.height,
|
||||
latent_area.left_x : latent_area.left_x + latent_area.width,
|
||||
] = unconditioned_next_x * text_conditioning.mask_strength
|
||||
cond_count[
|
||||
:,
|
||||
:,
|
||||
latent_area.top_y : latent_area.top_y + latent_area.height,
|
||||
latent_area.left_x : latent_area.left_x + latent_area.width,
|
||||
|
Loading…
Reference in New Issue
Block a user