mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Remove AddConditioningMaskInvocaton.
This commit is contained in:
parent
271f8f2414
commit
d313e5eb70
@ -5,7 +5,15 @@ from compel import Compel, ReturnedEmbeddingsType
|
||||
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIComponent
|
||||
from invokeai.app.invocations.fields import (
|
||||
ConditioningField,
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
MaskField,
|
||||
OutputField,
|
||||
UIComponent,
|
||||
)
|
||||
from invokeai.app.invocations.primitives import ConditioningOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.ti_utils import generate_ti_list
|
||||
@ -36,7 +44,7 @@ from .model import ClipField
|
||||
title="Prompt",
|
||||
tags=["prompt", "compel"],
|
||||
category="conditioning",
|
||||
version="1.0.1",
|
||||
version="1.1.0",
|
||||
)
|
||||
class CompelInvocation(BaseInvocation):
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
@ -51,6 +59,10 @@ class CompelInvocation(BaseInvocation):
|
||||
description=FieldDescriptions.clip,
|
||||
input=Input.Connection,
|
||||
)
|
||||
mask: Optional[MaskField] = InputField(
|
||||
default=None, description="A mask defining the region that this conditioning prompt applies to."
|
||||
)
|
||||
positive_cross_attn_mask_score: float = InputField(default=0.0, description="")
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
@ -118,7 +130,13 @@ class CompelInvocation(BaseInvocation):
|
||||
|
||||
conditioning_name = context.conditioning.save(conditioning_data)
|
||||
|
||||
return ConditioningOutput.build(conditioning_name)
|
||||
return ConditioningOutput(
|
||||
conditioning=ConditioningField(
|
||||
conditioning_name=conditioning_name,
|
||||
mask=self.mask,
|
||||
positive_cross_attn_mask_score=self.positive_cross_attn_mask_score,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class SDXLPromptInvocationBase:
|
||||
@ -232,7 +250,7 @@ class SDXLPromptInvocationBase:
|
||||
title="SDXL Prompt",
|
||||
tags=["sdxl", "compel", "prompt"],
|
||||
category="conditioning",
|
||||
version="1.0.1",
|
||||
version="1.1.0",
|
||||
)
|
||||
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
@ -256,6 +274,11 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1")
|
||||
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
|
||||
|
||||
mask: Optional[MaskField] = InputField(
|
||||
default=None, description="A mask defining the region that this conditioning prompt applies to."
|
||||
)
|
||||
positive_cross_attn_mask_score: float = InputField(default=0.0, description="")
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
c1, c1_pooled, ec1 = self.run_clip_compel(
|
||||
@ -317,7 +340,13 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
|
||||
conditioning_name = context.conditioning.save(conditioning_data)
|
||||
|
||||
return ConditioningOutput.build(conditioning_name)
|
||||
return ConditioningOutput(
|
||||
conditioning=ConditioningField(
|
||||
conditioning_name=conditioning_name,
|
||||
mask=self.mask,
|
||||
positive_cross_attn_mask_score=self.positive_cross_attn_mask_score,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
@ -366,7 +395,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
||||
|
||||
conditioning_name = context.conditioning.save(conditioning_data)
|
||||
|
||||
return ConditioningOutput.build(conditioning_name)
|
||||
return ConditioningOutput(conditioning=ConditioningField(conditioning_name=conditioning_name))
|
||||
|
||||
|
||||
@invocation_output("clip_skip_output")
|
||||
|
@ -6,27 +6,7 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
invocation,
|
||||
)
|
||||
from invokeai.app.invocations.fields import InputField, WithMetadata
|
||||
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, MaskField, MaskOutput
|
||||
|
||||
|
||||
@invocation(
|
||||
"add_conditioning_mask",
|
||||
title="Add Conditioning Mask",
|
||||
tags=["conditioning"],
|
||||
category="conditioning",
|
||||
version="1.0.0",
|
||||
)
|
||||
class AddConditioningMaskInvocation(BaseInvocation):
|
||||
"""Add a mask to an existing conditioning tensor."""
|
||||
|
||||
conditioning: ConditioningField = InputField(description="The conditioning tensor to add a mask to.")
|
||||
mask: MaskField = InputField(description="A mask to add to the conditioning tensor.")
|
||||
positive_cross_attn_mask_score: float = InputField(default=0.0, description="")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
self.conditioning.mask = self.mask
|
||||
self.conditioning.positive_cross_attn_mask_score = self.positive_cross_attn_mask_score
|
||||
return ConditioningOutput(conditioning=self.conditioning)
|
||||
from invokeai.app.invocations.primitives import MaskField, MaskOutput
|
||||
|
||||
|
||||
@invocation(
|
||||
|
@ -427,10 +427,6 @@ class ConditioningOutput(BaseInvocationOutput):
|
||||
|
||||
conditioning: ConditioningField = OutputField(description=FieldDescriptions.cond)
|
||||
|
||||
@classmethod
|
||||
def build(cls, conditioning_name: str) -> "ConditioningOutput":
|
||||
return cls(conditioning=ConditioningField(conditioning_name=conditioning_name))
|
||||
|
||||
|
||||
@invocation_output("conditioning_collection_output")
|
||||
class ConditioningCollectionOutput(BaseInvocationOutput):
|
||||
|
@ -47,7 +47,7 @@ class RegionalPromptData:
|
||||
# - Scale by region size.
|
||||
self.negative_cross_attn_mask_score = -10000
|
||||
# self.positive_cross_attn_mask_score = 0.0
|
||||
self.positive_self_attn_mask_score = 2.0
|
||||
self.positive_self_attn_mask_score = 1.0
|
||||
self.self_attn_mask_end_step_percent = 0.3
|
||||
# This one is for regional prompting in general, so should be set on the DenoiseLatents node.
|
||||
self.self_attn_score_range = 3.0
|
||||
@ -233,6 +233,7 @@ class RegionalPromptData:
|
||||
prompt_query_mask = batch_sample_query_masks[0, prompt_idx, :, 0] # Shape: (query_seq_len,)
|
||||
# Multiply a (1, query_seq_len) mask by a (query_seq_len, 1) mask to get a (query_seq_len,
|
||||
# query_seq_len) mask.
|
||||
# TODO(ryand): Is += really the best option here?
|
||||
attn_mask[batch_idx, :, :] += (
|
||||
prompt_query_mask.unsqueeze(0) * prompt_query_mask.unsqueeze(1) * self.positive_self_attn_mask_score
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user