Remove AddConditioningMaskInvocaton.

This commit is contained in:
Ryan Dick 2024-03-04 14:11:38 -05:00
parent 271f8f2414
commit d313e5eb70
4 changed files with 38 additions and 32 deletions

View File

@ -5,7 +5,15 @@ from compel import Compel, ReturnedEmbeddingsType
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIComponent from invokeai.app.invocations.fields import (
ConditioningField,
FieldDescriptions,
Input,
InputField,
MaskField,
OutputField,
UIComponent,
)
from invokeai.app.invocations.primitives import ConditioningOutput from invokeai.app.invocations.primitives import ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.ti_utils import generate_ti_list from invokeai.app.util.ti_utils import generate_ti_list
@ -36,7 +44,7 @@ from .model import ClipField
title="Prompt", title="Prompt",
tags=["prompt", "compel"], tags=["prompt", "compel"],
category="conditioning", category="conditioning",
version="1.0.1", version="1.1.0",
) )
class CompelInvocation(BaseInvocation): class CompelInvocation(BaseInvocation):
"""Parse prompt using compel package to conditioning.""" """Parse prompt using compel package to conditioning."""
@ -51,6 +59,10 @@ class CompelInvocation(BaseInvocation):
description=FieldDescriptions.clip, description=FieldDescriptions.clip,
input=Input.Connection, input=Input.Connection,
) )
mask: Optional[MaskField] = InputField(
default=None, description="A mask defining the region that this conditioning prompt applies to."
)
positive_cross_attn_mask_score: float = InputField(default=0.0, description="")
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput: def invoke(self, context: InvocationContext) -> ConditioningOutput:
@ -118,7 +130,13 @@ class CompelInvocation(BaseInvocation):
conditioning_name = context.conditioning.save(conditioning_data) conditioning_name = context.conditioning.save(conditioning_data)
return ConditioningOutput.build(conditioning_name) return ConditioningOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
mask=self.mask,
positive_cross_attn_mask_score=self.positive_cross_attn_mask_score,
)
)
class SDXLPromptInvocationBase: class SDXLPromptInvocationBase:
@ -232,7 +250,7 @@ class SDXLPromptInvocationBase:
title="SDXL Prompt", title="SDXL Prompt",
tags=["sdxl", "compel", "prompt"], tags=["sdxl", "compel", "prompt"],
category="conditioning", category="conditioning",
version="1.0.1", version="1.1.0",
) )
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
"""Parse prompt using compel package to conditioning.""" """Parse prompt using compel package to conditioning."""
@ -256,6 +274,11 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1") clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1")
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2") clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
mask: Optional[MaskField] = InputField(
default=None, description="A mask defining the region that this conditioning prompt applies to."
)
positive_cross_attn_mask_score: float = InputField(default=0.0, description="")
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput: def invoke(self, context: InvocationContext) -> ConditioningOutput:
c1, c1_pooled, ec1 = self.run_clip_compel( c1, c1_pooled, ec1 = self.run_clip_compel(
@ -317,7 +340,13 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
conditioning_name = context.conditioning.save(conditioning_data) conditioning_name = context.conditioning.save(conditioning_data)
return ConditioningOutput.build(conditioning_name) return ConditioningOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
mask=self.mask,
positive_cross_attn_mask_score=self.positive_cross_attn_mask_score,
)
)
@invocation( @invocation(
@ -366,7 +395,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
conditioning_name = context.conditioning.save(conditioning_data) conditioning_name = context.conditioning.save(conditioning_data)
return ConditioningOutput.build(conditioning_name) return ConditioningOutput(conditioning=ConditioningField(conditioning_name=conditioning_name))
@invocation_output("clip_skip_output") @invocation_output("clip_skip_output")

View File

@ -6,27 +6,7 @@ from invokeai.app.invocations.baseinvocation import (
invocation, invocation,
) )
from invokeai.app.invocations.fields import InputField, WithMetadata from invokeai.app.invocations.fields import InputField, WithMetadata
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, MaskField, MaskOutput from invokeai.app.invocations.primitives import MaskField, MaskOutput
@invocation(
"add_conditioning_mask",
title="Add Conditioning Mask",
tags=["conditioning"],
category="conditioning",
version="1.0.0",
)
class AddConditioningMaskInvocation(BaseInvocation):
"""Add a mask to an existing conditioning tensor."""
conditioning: ConditioningField = InputField(description="The conditioning tensor to add a mask to.")
mask: MaskField = InputField(description="A mask to add to the conditioning tensor.")
positive_cross_attn_mask_score: float = InputField(default=0.0, description="")
def invoke(self, context: InvocationContext) -> ConditioningOutput:
self.conditioning.mask = self.mask
self.conditioning.positive_cross_attn_mask_score = self.positive_cross_attn_mask_score
return ConditioningOutput(conditioning=self.conditioning)
@invocation( @invocation(

View File

@ -427,10 +427,6 @@ class ConditioningOutput(BaseInvocationOutput):
conditioning: ConditioningField = OutputField(description=FieldDescriptions.cond) conditioning: ConditioningField = OutputField(description=FieldDescriptions.cond)
@classmethod
def build(cls, conditioning_name: str) -> "ConditioningOutput":
return cls(conditioning=ConditioningField(conditioning_name=conditioning_name))
@invocation_output("conditioning_collection_output") @invocation_output("conditioning_collection_output")
class ConditioningCollectionOutput(BaseInvocationOutput): class ConditioningCollectionOutput(BaseInvocationOutput):

View File

@ -47,7 +47,7 @@ class RegionalPromptData:
# - Scale by region size. # - Scale by region size.
self.negative_cross_attn_mask_score = -10000 self.negative_cross_attn_mask_score = -10000
# self.positive_cross_attn_mask_score = 0.0 # self.positive_cross_attn_mask_score = 0.0
self.positive_self_attn_mask_score = 2.0 self.positive_self_attn_mask_score = 1.0
self.self_attn_mask_end_step_percent = 0.3 self.self_attn_mask_end_step_percent = 0.3
# This one is for regional prompting in general, so should be set on the DenoiseLatents node. # This one is for regional prompting in general, so should be set on the DenoiseLatents node.
self.self_attn_score_range = 3.0 self.self_attn_score_range = 3.0
@ -233,6 +233,7 @@ class RegionalPromptData:
prompt_query_mask = batch_sample_query_masks[0, prompt_idx, :, 0] # Shape: (query_seq_len,) prompt_query_mask = batch_sample_query_masks[0, prompt_idx, :, 0] # Shape: (query_seq_len,)
# Multiply a (1, query_seq_len) mask by a (query_seq_len, 1) mask to get a (query_seq_len, # Multiply a (1, query_seq_len) mask by a (query_seq_len, 1) mask to get a (query_seq_len,
# query_seq_len) mask. # query_seq_len) mask.
# TODO(ryand): Is += really the best option here?
attn_mask[batch_idx, :, :] += ( attn_mask[batch_idx, :, :] += (
prompt_query_mask.unsqueeze(0) * prompt_query_mask.unsqueeze(1) * self.positive_self_attn_mask_score prompt_query_mask.unsqueeze(0) * prompt_query_mask.unsqueeze(1) * self.positive_self_attn_mask_score
) )