mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Remove AddConditioningMaskInvocaton.
This commit is contained in:
parent
271f8f2414
commit
d313e5eb70
@ -5,7 +5,15 @@ from compel import Compel, ReturnedEmbeddingsType
|
|||||||
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
|
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
|
||||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
||||||
|
|
||||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIComponent
|
from invokeai.app.invocations.fields import (
|
||||||
|
ConditioningField,
|
||||||
|
FieldDescriptions,
|
||||||
|
Input,
|
||||||
|
InputField,
|
||||||
|
MaskField,
|
||||||
|
OutputField,
|
||||||
|
UIComponent,
|
||||||
|
)
|
||||||
from invokeai.app.invocations.primitives import ConditioningOutput
|
from invokeai.app.invocations.primitives import ConditioningOutput
|
||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.app.util.ti_utils import generate_ti_list
|
from invokeai.app.util.ti_utils import generate_ti_list
|
||||||
@ -36,7 +44,7 @@ from .model import ClipField
|
|||||||
title="Prompt",
|
title="Prompt",
|
||||||
tags=["prompt", "compel"],
|
tags=["prompt", "compel"],
|
||||||
category="conditioning",
|
category="conditioning",
|
||||||
version="1.0.1",
|
version="1.1.0",
|
||||||
)
|
)
|
||||||
class CompelInvocation(BaseInvocation):
|
class CompelInvocation(BaseInvocation):
|
||||||
"""Parse prompt using compel package to conditioning."""
|
"""Parse prompt using compel package to conditioning."""
|
||||||
@ -51,6 +59,10 @@ class CompelInvocation(BaseInvocation):
|
|||||||
description=FieldDescriptions.clip,
|
description=FieldDescriptions.clip,
|
||||||
input=Input.Connection,
|
input=Input.Connection,
|
||||||
)
|
)
|
||||||
|
mask: Optional[MaskField] = InputField(
|
||||||
|
default=None, description="A mask defining the region that this conditioning prompt applies to."
|
||||||
|
)
|
||||||
|
positive_cross_attn_mask_score: float = InputField(default=0.0, description="")
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||||
@ -118,7 +130,13 @@ class CompelInvocation(BaseInvocation):
|
|||||||
|
|
||||||
conditioning_name = context.conditioning.save(conditioning_data)
|
conditioning_name = context.conditioning.save(conditioning_data)
|
||||||
|
|
||||||
return ConditioningOutput.build(conditioning_name)
|
return ConditioningOutput(
|
||||||
|
conditioning=ConditioningField(
|
||||||
|
conditioning_name=conditioning_name,
|
||||||
|
mask=self.mask,
|
||||||
|
positive_cross_attn_mask_score=self.positive_cross_attn_mask_score,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class SDXLPromptInvocationBase:
|
class SDXLPromptInvocationBase:
|
||||||
@ -232,7 +250,7 @@ class SDXLPromptInvocationBase:
|
|||||||
title="SDXL Prompt",
|
title="SDXL Prompt",
|
||||||
tags=["sdxl", "compel", "prompt"],
|
tags=["sdxl", "compel", "prompt"],
|
||||||
category="conditioning",
|
category="conditioning",
|
||||||
version="1.0.1",
|
version="1.1.0",
|
||||||
)
|
)
|
||||||
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||||
"""Parse prompt using compel package to conditioning."""
|
"""Parse prompt using compel package to conditioning."""
|
||||||
@ -256,6 +274,11 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
|||||||
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1")
|
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1")
|
||||||
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
|
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
|
||||||
|
|
||||||
|
mask: Optional[MaskField] = InputField(
|
||||||
|
default=None, description="A mask defining the region that this conditioning prompt applies to."
|
||||||
|
)
|
||||||
|
positive_cross_attn_mask_score: float = InputField(default=0.0, description="")
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||||
c1, c1_pooled, ec1 = self.run_clip_compel(
|
c1, c1_pooled, ec1 = self.run_clip_compel(
|
||||||
@ -317,7 +340,13 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
|||||||
|
|
||||||
conditioning_name = context.conditioning.save(conditioning_data)
|
conditioning_name = context.conditioning.save(conditioning_data)
|
||||||
|
|
||||||
return ConditioningOutput.build(conditioning_name)
|
return ConditioningOutput(
|
||||||
|
conditioning=ConditioningField(
|
||||||
|
conditioning_name=conditioning_name,
|
||||||
|
mask=self.mask,
|
||||||
|
positive_cross_attn_mask_score=self.positive_cross_attn_mask_score,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
@ -366,7 +395,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
|||||||
|
|
||||||
conditioning_name = context.conditioning.save(conditioning_data)
|
conditioning_name = context.conditioning.save(conditioning_data)
|
||||||
|
|
||||||
return ConditioningOutput.build(conditioning_name)
|
return ConditioningOutput(conditioning=ConditioningField(conditioning_name=conditioning_name))
|
||||||
|
|
||||||
|
|
||||||
@invocation_output("clip_skip_output")
|
@invocation_output("clip_skip_output")
|
||||||
|
@ -6,27 +6,7 @@ from invokeai.app.invocations.baseinvocation import (
|
|||||||
invocation,
|
invocation,
|
||||||
)
|
)
|
||||||
from invokeai.app.invocations.fields import InputField, WithMetadata
|
from invokeai.app.invocations.fields import InputField, WithMetadata
|
||||||
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, MaskField, MaskOutput
|
from invokeai.app.invocations.primitives import MaskField, MaskOutput
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
|
||||||
"add_conditioning_mask",
|
|
||||||
title="Add Conditioning Mask",
|
|
||||||
tags=["conditioning"],
|
|
||||||
category="conditioning",
|
|
||||||
version="1.0.0",
|
|
||||||
)
|
|
||||||
class AddConditioningMaskInvocation(BaseInvocation):
|
|
||||||
"""Add a mask to an existing conditioning tensor."""
|
|
||||||
|
|
||||||
conditioning: ConditioningField = InputField(description="The conditioning tensor to add a mask to.")
|
|
||||||
mask: MaskField = InputField(description="A mask to add to the conditioning tensor.")
|
|
||||||
positive_cross_attn_mask_score: float = InputField(default=0.0, description="")
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
|
||||||
self.conditioning.mask = self.mask
|
|
||||||
self.conditioning.positive_cross_attn_mask_score = self.positive_cross_attn_mask_score
|
|
||||||
return ConditioningOutput(conditioning=self.conditioning)
|
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
|
@ -427,10 +427,6 @@ class ConditioningOutput(BaseInvocationOutput):
|
|||||||
|
|
||||||
conditioning: ConditioningField = OutputField(description=FieldDescriptions.cond)
|
conditioning: ConditioningField = OutputField(description=FieldDescriptions.cond)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def build(cls, conditioning_name: str) -> "ConditioningOutput":
|
|
||||||
return cls(conditioning=ConditioningField(conditioning_name=conditioning_name))
|
|
||||||
|
|
||||||
|
|
||||||
@invocation_output("conditioning_collection_output")
|
@invocation_output("conditioning_collection_output")
|
||||||
class ConditioningCollectionOutput(BaseInvocationOutput):
|
class ConditioningCollectionOutput(BaseInvocationOutput):
|
||||||
|
@ -47,7 +47,7 @@ class RegionalPromptData:
|
|||||||
# - Scale by region size.
|
# - Scale by region size.
|
||||||
self.negative_cross_attn_mask_score = -10000
|
self.negative_cross_attn_mask_score = -10000
|
||||||
# self.positive_cross_attn_mask_score = 0.0
|
# self.positive_cross_attn_mask_score = 0.0
|
||||||
self.positive_self_attn_mask_score = 2.0
|
self.positive_self_attn_mask_score = 1.0
|
||||||
self.self_attn_mask_end_step_percent = 0.3
|
self.self_attn_mask_end_step_percent = 0.3
|
||||||
# This one is for regional prompting in general, so should be set on the DenoiseLatents node.
|
# This one is for regional prompting in general, so should be set on the DenoiseLatents node.
|
||||||
self.self_attn_score_range = 3.0
|
self.self_attn_score_range = 3.0
|
||||||
@ -233,6 +233,7 @@ class RegionalPromptData:
|
|||||||
prompt_query_mask = batch_sample_query_masks[0, prompt_idx, :, 0] # Shape: (query_seq_len,)
|
prompt_query_mask = batch_sample_query_masks[0, prompt_idx, :, 0] # Shape: (query_seq_len,)
|
||||||
# Multiply a (1, query_seq_len) mask by a (query_seq_len, 1) mask to get a (query_seq_len,
|
# Multiply a (1, query_seq_len) mask by a (query_seq_len, 1) mask to get a (query_seq_len,
|
||||||
# query_seq_len) mask.
|
# query_seq_len) mask.
|
||||||
|
# TODO(ryand): Is += really the best option here?
|
||||||
attn_mask[batch_idx, :, :] += (
|
attn_mask[batch_idx, :, :] += (
|
||||||
prompt_query_mask.unsqueeze(0) * prompt_query_mask.unsqueeze(1) * self.positive_self_attn_mask_score
|
prompt_query_mask.unsqueeze(0) * prompt_query_mask.unsqueeze(1) * self.positive_self_attn_mask_score
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user