diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index ff13658052..620c1acb1e 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -5,7 +5,15 @@ from compel import Compel, ReturnedEmbeddingsType from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer -from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIComponent +from invokeai.app.invocations.fields import ( + ConditioningField, + FieldDescriptions, + Input, + InputField, + MaskField, + OutputField, + UIComponent, +) from invokeai.app.invocations.primitives import ConditioningOutput from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.ti_utils import generate_ti_list @@ -36,7 +44,7 @@ from .model import ClipField title="Prompt", tags=["prompt", "compel"], category="conditioning", - version="1.0.1", + version="1.1.0", ) class CompelInvocation(BaseInvocation): """Parse prompt using compel package to conditioning.""" @@ -51,6 +59,10 @@ class CompelInvocation(BaseInvocation): description=FieldDescriptions.clip, input=Input.Connection, ) + mask: Optional[MaskField] = InputField( + default=None, description="A mask defining the region that this conditioning prompt applies to." + ) + positive_cross_attn_mask_score: float = InputField(default=0.0, description="") @torch.no_grad() def invoke(self, context: InvocationContext) -> ConditioningOutput: @@ -118,7 +130,13 @@ class CompelInvocation(BaseInvocation): conditioning_name = context.conditioning.save(conditioning_data) - return ConditioningOutput.build(conditioning_name) + return ConditioningOutput( + conditioning=ConditioningField( + conditioning_name=conditioning_name, + mask=self.mask, + positive_cross_attn_mask_score=self.positive_cross_attn_mask_score, + ) + ) class SDXLPromptInvocationBase: @@ -232,7 +250,7 @@ class SDXLPromptInvocationBase: title="SDXL Prompt", tags=["sdxl", "compel", "prompt"], category="conditioning", - version="1.0.1", + version="1.1.0", ) class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): """Parse prompt using compel package to conditioning.""" @@ -256,6 +274,11 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1") clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2") + mask: Optional[MaskField] = InputField( + default=None, description="A mask defining the region that this conditioning prompt applies to." + ) + positive_cross_attn_mask_score: float = InputField(default=0.0, description="") + @torch.no_grad() def invoke(self, context: InvocationContext) -> ConditioningOutput: c1, c1_pooled, ec1 = self.run_clip_compel( @@ -317,7 +340,13 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): conditioning_name = context.conditioning.save(conditioning_data) - return ConditioningOutput.build(conditioning_name) + return ConditioningOutput( + conditioning=ConditioningField( + conditioning_name=conditioning_name, + mask=self.mask, + positive_cross_attn_mask_score=self.positive_cross_attn_mask_score, + ) + ) @invocation( @@ -366,7 +395,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase conditioning_name = context.conditioning.save(conditioning_data) - return ConditioningOutput.build(conditioning_name) + return ConditioningOutput(conditioning=ConditioningField(conditioning_name=conditioning_name)) @invocation_output("clip_skip_output") diff --git a/invokeai/app/invocations/conditioning.py b/invokeai/app/invocations/conditioning.py index 595a20e186..7fbc879be2 100644 --- a/invokeai/app/invocations/conditioning.py +++ b/invokeai/app/invocations/conditioning.py @@ -6,27 +6,7 @@ from invokeai.app.invocations.baseinvocation import ( invocation, ) from invokeai.app.invocations.fields import InputField, WithMetadata -from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, MaskField, MaskOutput - - -@invocation( - "add_conditioning_mask", - title="Add Conditioning Mask", - tags=["conditioning"], - category="conditioning", - version="1.0.0", -) -class AddConditioningMaskInvocation(BaseInvocation): - """Add a mask to an existing conditioning tensor.""" - - conditioning: ConditioningField = InputField(description="The conditioning tensor to add a mask to.") - mask: MaskField = InputField(description="A mask to add to the conditioning tensor.") - positive_cross_attn_mask_score: float = InputField(default=0.0, description="") - - def invoke(self, context: InvocationContext) -> ConditioningOutput: - self.conditioning.mask = self.mask - self.conditioning.positive_cross_attn_mask_score = self.positive_cross_attn_mask_score - return ConditioningOutput(conditioning=self.conditioning) +from invokeai.app.invocations.primitives import MaskField, MaskOutput @invocation( diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index 0b1dea7ff3..2b73061bcf 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -427,10 +427,6 @@ class ConditioningOutput(BaseInvocationOutput): conditioning: ConditioningField = OutputField(description=FieldDescriptions.cond) - @classmethod - def build(cls, conditioning_name: str) -> "ConditioningOutput": - return cls(conditioning=ConditioningField(conditioning_name=conditioning_name)) - @invocation_output("conditioning_collection_output") class ConditioningCollectionOutput(BaseInvocationOutput): diff --git a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py index 57b956924c..82f07385a9 100644 --- a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py +++ b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py @@ -47,7 +47,7 @@ class RegionalPromptData: # - Scale by region size. self.negative_cross_attn_mask_score = -10000 # self.positive_cross_attn_mask_score = 0.0 - self.positive_self_attn_mask_score = 2.0 + self.positive_self_attn_mask_score = 1.0 self.self_attn_mask_end_step_percent = 0.3 # This one is for regional prompting in general, so should be set on the DenoiseLatents node. self.self_attn_score_range = 3.0 @@ -233,6 +233,7 @@ class RegionalPromptData: prompt_query_mask = batch_sample_query_masks[0, prompt_idx, :, 0] # Shape: (query_seq_len,) # Multiply a (1, query_seq_len) mask by a (query_seq_len, 1) mask to get a (query_seq_len, # query_seq_len) mask. + # TODO(ryand): Is += really the best option here? attn_mask[batch_idx, :, :] += ( prompt_query_mask.unsqueeze(0) * prompt_query_mask.unsqueeze(1) * self.positive_self_attn_mask_score )