mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Get area-based regional prompting working.
This commit is contained in:
parent
bc869cad62
commit
ef8177ba4b
@ -423,6 +423,7 @@ class ColorInvocation(BaseInvocation):
|
|||||||
|
|
||||||
# region Conditioning
|
# region Conditioning
|
||||||
|
|
||||||
|
|
||||||
class DenoisingArea(BaseModel):
|
class DenoisingArea(BaseModel):
|
||||||
top_y: int = Field(description="TODO")
|
top_y: int = Field(description="TODO")
|
||||||
left_x: int = Field(description="TODO")
|
left_x: int = Field(description="TODO")
|
||||||
@ -439,7 +440,10 @@ class ConditioningField(BaseModel):
|
|||||||
# description="The mask associated with this conditioning tensor. Excluded regions should be set to 0, included "
|
# description="The mask associated with this conditioning tensor. Excluded regions should be set to 0, included "
|
||||||
# "regions should be set to 1.",
|
# "regions should be set to 1.",
|
||||||
# )
|
# )
|
||||||
denoising_area: DenoisingArea = Field(description="The area to apply this conditioning to. Coordinates are in latent space.")
|
denoising_area: Optional[DenoisingArea] = Field(
|
||||||
|
description="The area to apply this conditioning to. Coordinates are in latent space.",
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
mask_strength: float = Field(
|
mask_strength: float = Field(
|
||||||
default=1.0,
|
default=1.0,
|
||||||
description="The strength of the mask. Only has an effect if mask_name is set. The strength is relative to "
|
description="The strength of the mask. Only has an effect if mask_name is set. The strength is relative to "
|
||||||
|
@ -45,7 +45,7 @@ class TextConditioningInfoWithArea:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
text_conditioning_info: Union[BasicConditioningInfo, SDXLConditioningInfo],
|
text_conditioning_info: Union[BasicConditioningInfo, SDXLConditioningInfo],
|
||||||
area: DenoisingArea,
|
area: Optional[DenoisingArea],
|
||||||
mask_strength: float,
|
mask_strength: float,
|
||||||
):
|
):
|
||||||
self.text_conditioning_info = text_conditioning_info
|
self.text_conditioning_info = text_conditioning_info
|
||||||
|
@ -5,11 +5,10 @@ from contextlib import contextmanager
|
|||||||
from typing import Any, Callable, Optional, Union
|
from typing import Any, Callable, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torchvision
|
|
||||||
from diffusers import UNet2DConditionModel
|
from diffusers import UNet2DConditionModel
|
||||||
from typing_extensions import TypeAlias
|
from typing_extensions import TypeAlias
|
||||||
from invokeai.app.invocations.primitives import DenoisingArea
|
|
||||||
|
|
||||||
|
from invokeai.app.invocations.primitives import DenoisingArea
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||||
BasicConditioningInfo,
|
BasicConditioningInfo,
|
||||||
@ -256,13 +255,21 @@ class InvokeAIDiffuserComponent:
|
|||||||
# )
|
# )
|
||||||
else:
|
else:
|
||||||
area = text_conditioning.area
|
area = text_conditioning.area
|
||||||
# TODO(ryand): Use LATENT_SCALE_FACTOR instead of hard-coding to 8.
|
if area is None:
|
||||||
latent_area = DenoisingArea(
|
latent_area = DenoisingArea(
|
||||||
height=max(1, area.height // 8),
|
height=latent_shape[2],
|
||||||
width=max(1, area.width // 8),
|
width=latent_shape[3],
|
||||||
top_y=area.top_y // 8,
|
top_y=0,
|
||||||
left_x=area.left_x // 8,
|
left_x=0,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
# TODO(ryand): Use LATENT_SCALE_FACTOR instead of hard-coding to 8.
|
||||||
|
latent_area = DenoisingArea(
|
||||||
|
height=max(1, area.height // 8),
|
||||||
|
width=max(1, area.width // 8),
|
||||||
|
top_y=area.top_y // 8,
|
||||||
|
left_x=area.left_x // 8,
|
||||||
|
)
|
||||||
area_sample = sample[
|
area_sample = sample[
|
||||||
:,
|
:,
|
||||||
:,
|
:,
|
||||||
@ -288,16 +295,19 @@ class InvokeAIDiffuserComponent:
|
|||||||
cond_next_xs[
|
cond_next_xs[
|
||||||
:,
|
:,
|
||||||
i,
|
i,
|
||||||
|
:,
|
||||||
latent_area.top_y : latent_area.top_y + latent_area.height,
|
latent_area.top_y : latent_area.top_y + latent_area.height,
|
||||||
latent_area.left_x : latent_area.left_x + latent_area.width,
|
latent_area.left_x : latent_area.left_x + latent_area.width,
|
||||||
] = conditioned_next_x * text_conditioning.mask_strength
|
] = conditioned_next_x * text_conditioning.mask_strength
|
||||||
uncond_next_xs[
|
uncond_next_xs[
|
||||||
:,
|
:,
|
||||||
i,
|
i,
|
||||||
|
:,
|
||||||
latent_area.top_y : latent_area.top_y + latent_area.height,
|
latent_area.top_y : latent_area.top_y + latent_area.height,
|
||||||
latent_area.left_x : latent_area.left_x + latent_area.width,
|
latent_area.left_x : latent_area.left_x + latent_area.width,
|
||||||
] = unconditioned_next_x * text_conditioning.mask_strength
|
] = unconditioned_next_x * text_conditioning.mask_strength
|
||||||
cond_count[
|
cond_count[
|
||||||
|
:,
|
||||||
:,
|
:,
|
||||||
latent_area.top_y : latent_area.top_y + latent_area.height,
|
latent_area.top_y : latent_area.top_y + latent_area.height,
|
||||||
latent_area.left_x : latent_area.left_x + latent_area.width,
|
latent_area.left_x : latent_area.left_x + latent_area.width,
|
||||||
|
Loading…
Reference in New Issue
Block a user