mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add support for mask weights, and only mask the tokens associated with the prompts (not the entire 77-token embedding).
This commit is contained in:
parent
969982b789
commit
ff950bc5cd
@ -44,7 +44,7 @@ from .model import ClipField
|
|||||||
title="Prompt",
|
title="Prompt",
|
||||||
tags=["prompt", "compel"],
|
tags=["prompt", "compel"],
|
||||||
category="conditioning",
|
category="conditioning",
|
||||||
version="1.1.0",
|
version="1.2.0",
|
||||||
)
|
)
|
||||||
class CompelInvocation(BaseInvocation):
|
class CompelInvocation(BaseInvocation):
|
||||||
"""Parse prompt using compel package to conditioning."""
|
"""Parse prompt using compel package to conditioning."""
|
||||||
@ -62,9 +62,7 @@ class CompelInvocation(BaseInvocation):
|
|||||||
mask: Optional[MaskField] = InputField(
|
mask: Optional[MaskField] = InputField(
|
||||||
default=None, description="A mask defining the region that this conditioning prompt applies to."
|
default=None, description="A mask defining the region that this conditioning prompt applies to."
|
||||||
)
|
)
|
||||||
positive_cross_attn_mask_score: float = InputField(default=0.0, description="")
|
mask_weight: float = InputField(default=1.0, description="")
|
||||||
positive_self_attn_mask_score: float = InputField(default=1.0, description="")
|
|
||||||
self_attn_adjustment_end_step_percent: float = InputField(default=0.0, description="")
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||||
@ -136,9 +134,7 @@ class CompelInvocation(BaseInvocation):
|
|||||||
conditioning=ConditioningField(
|
conditioning=ConditioningField(
|
||||||
conditioning_name=conditioning_name,
|
conditioning_name=conditioning_name,
|
||||||
mask=self.mask,
|
mask=self.mask,
|
||||||
positive_cross_attn_mask_score=self.positive_cross_attn_mask_score,
|
mask_weight=self.mask_weight,
|
||||||
positive_self_attn_mask_score=self.positive_self_attn_mask_score,
|
|
||||||
self_attn_adjustment_end_step_percent=self.self_attn_adjustment_end_step_percent,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -254,7 +250,7 @@ class SDXLPromptInvocationBase:
|
|||||||
title="SDXL Prompt",
|
title="SDXL Prompt",
|
||||||
tags=["sdxl", "compel", "prompt"],
|
tags=["sdxl", "compel", "prompt"],
|
||||||
category="conditioning",
|
category="conditioning",
|
||||||
version="1.1.0",
|
version="1.2.0",
|
||||||
)
|
)
|
||||||
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||||
"""Parse prompt using compel package to conditioning."""
|
"""Parse prompt using compel package to conditioning."""
|
||||||
@ -281,9 +277,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
|||||||
mask: Optional[MaskField] = InputField(
|
mask: Optional[MaskField] = InputField(
|
||||||
default=None, description="A mask defining the region that this conditioning prompt applies to."
|
default=None, description="A mask defining the region that this conditioning prompt applies to."
|
||||||
)
|
)
|
||||||
positive_cross_attn_mask_score: float = InputField(default=0.0, description="")
|
mask_weight: float = InputField(default=1.0, description="")
|
||||||
positive_self_attn_mask_score: float = InputField(default=1.0, description="")
|
|
||||||
self_attn_adjustment_end_step_percent: float = InputField(default=0.0, description="")
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||||
@ -350,9 +344,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
|||||||
conditioning=ConditioningField(
|
conditioning=ConditioningField(
|
||||||
conditioning_name=conditioning_name,
|
conditioning_name=conditioning_name,
|
||||||
mask=self.mask,
|
mask=self.mask,
|
||||||
positive_cross_attn_mask_score=self.positive_cross_attn_mask_score,
|
mask_weight=self.mask_weight,
|
||||||
positive_self_attn_mask_score=self.positive_self_attn_mask_score,
|
|
||||||
self_attn_adjustment_end_step_percent=self.self_attn_adjustment_end_step_percent,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -403,7 +395,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
|||||||
|
|
||||||
conditioning_name = context.conditioning.save(conditioning_data)
|
conditioning_name = context.conditioning.save(conditioning_data)
|
||||||
|
|
||||||
return ConditioningOutput(conditioning=ConditioningField(conditioning_name=conditioning_name))
|
return ConditioningOutput(conditioning=ConditioningField(conditioning_name=conditioning_name, mask_weight=1.0))
|
||||||
|
|
||||||
|
|
||||||
@invocation_output("clip_skip_output")
|
@invocation_output("clip_skip_output")
|
||||||
|
@ -236,13 +236,7 @@ class ConditioningField(BaseModel):
|
|||||||
description="The bool mask associated with this conditioning tensor. Excluded regions should be set to False, "
|
description="The bool mask associated with this conditioning tensor. Excluded regions should be set to False, "
|
||||||
"included regions should be set to True.",
|
"included regions should be set to True.",
|
||||||
)
|
)
|
||||||
positive_cross_attn_mask_score: float = Field(
|
mask_weight: float = Field(description="")
|
||||||
default=0.0,
|
|
||||||
# TODO(ryand): Add more details to this description
|
|
||||||
description="The weight of this conditioning tensor's mask relative to overlapping masks.",
|
|
||||||
)
|
|
||||||
positive_self_attn_mask_score: float = Field(default=1.0, description="")
|
|
||||||
self_attn_adjustment_end_step_percent: float = Field(default=0.0, description="")
|
|
||||||
|
|
||||||
|
|
||||||
class MetadataField(RootModel):
|
class MetadataField(RootModel):
|
||||||
|
@ -466,9 +466,18 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
text_embedding.append(text_embedding_info.embeds)
|
text_embedding.append(text_embedding_info.embeds)
|
||||||
if not all_masks_are_none:
|
if not all_masks_are_none:
|
||||||
|
# embedding_ranges.append(
|
||||||
|
# Range(
|
||||||
|
# start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1]
|
||||||
|
# )
|
||||||
|
# )
|
||||||
|
# HACK(ryand): Contrary to its name, tokens_count_including_eos_bos does not seem to include eos and bos
|
||||||
|
# in the count.
|
||||||
embedding_ranges.append(
|
embedding_ranges.append(
|
||||||
Range(
|
Range(
|
||||||
start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1]
|
start=cur_text_embedding_len + 1,
|
||||||
|
end=cur_text_embedding_len
|
||||||
|
+ text_embedding_info.extra_conditioning.tokens_count_including_eos_bos,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
processed_masks.append(self._preprocess_regional_prompt_mask(mask, latent_height, latent_width))
|
processed_masks.append(self._preprocess_regional_prompt_mask(mask, latent_height, latent_width))
|
||||||
@ -483,11 +492,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
regions = TextConditioningRegions(
|
regions = TextConditioningRegions(
|
||||||
masks=torch.cat(processed_masks, dim=1),
|
masks=torch.cat(processed_masks, dim=1),
|
||||||
ranges=embedding_ranges,
|
ranges=embedding_ranges,
|
||||||
positive_cross_attn_mask_scores=[x.positive_cross_attn_mask_score for x in conditioning_fields],
|
mask_weights=[x.mask_weight for x in conditioning_fields],
|
||||||
positive_self_attn_mask_scores=[x.positive_self_attn_mask_score for x in conditioning_fields],
|
|
||||||
self_attn_adjustment_end_step_percents=[
|
|
||||||
x.self_attn_adjustment_end_step_percent for x in conditioning_fields
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if extra_conditioning is not None and len(text_conditionings) > 1:
|
if extra_conditioning is not None and len(text_conditionings) > 1:
|
||||||
|
@ -74,9 +74,7 @@ class TextConditioningRegions:
|
|||||||
self,
|
self,
|
||||||
masks: torch.Tensor,
|
masks: torch.Tensor,
|
||||||
ranges: list[Range],
|
ranges: list[Range],
|
||||||
positive_cross_attn_mask_scores: list[float],
|
mask_weights: list[float],
|
||||||
positive_self_attn_mask_scores: list[float],
|
|
||||||
self_attn_adjustment_end_step_percents: list[float],
|
|
||||||
):
|
):
|
||||||
# A binary mask indicating the regions of the image that the prompt should be applied to.
|
# A binary mask indicating the regions of the image that the prompt should be applied to.
|
||||||
# Shape: (1, num_prompts, height, width)
|
# Shape: (1, num_prompts, height, width)
|
||||||
@ -87,19 +85,9 @@ class TextConditioningRegions:
|
|||||||
# ranges[i] contains the embedding range for the i'th prompt / mask.
|
# ranges[i] contains the embedding range for the i'th prompt / mask.
|
||||||
self.ranges = ranges
|
self.ranges = ranges
|
||||||
|
|
||||||
# The following fields control the cross-attention and self-attention handling of region masks. They are indexed
|
self.mask_weights = mask_weights
|
||||||
# by prompt / mask index.
|
|
||||||
self.positive_cross_attn_mask_scores = positive_cross_attn_mask_scores
|
|
||||||
self.positive_self_attn_mask_scores = positive_self_attn_mask_scores
|
|
||||||
self.self_attn_adjustment_end_step_percents = self_attn_adjustment_end_step_percents
|
|
||||||
|
|
||||||
assert (
|
assert self.masks.shape[1] == len(self.ranges) == len(self.mask_weights)
|
||||||
self.masks.shape[1]
|
|
||||||
== len(self.ranges)
|
|
||||||
== len(self.positive_cross_attn_mask_scores)
|
|
||||||
== len(self.positive_self_attn_mask_scores)
|
|
||||||
== len(self.self_attn_adjustment_end_step_percents)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TextConditioningData:
|
class TextConditioningData:
|
||||||
|
@ -33,6 +33,7 @@ class RegionalPromptData:
|
|||||||
regions, max_downscale_factor
|
regions, max_downscale_factor
|
||||||
)
|
)
|
||||||
self._negative_cross_attn_mask_score = 0.0
|
self._negative_cross_attn_mask_score = 0.0
|
||||||
|
self._size_weight = 1.0
|
||||||
|
|
||||||
def _prepare_spatial_masks(
|
def _prepare_spatial_masks(
|
||||||
self, regions: list[TextConditioningRegions], max_downscale_factor: int = 8
|
self, regions: list[TextConditioningRegions], max_downscale_factor: int = 8
|
||||||
@ -96,13 +97,16 @@ class RegionalPromptData:
|
|||||||
batch_sample_query_masks = batch_sample_spatial_masks.view((1, num_prompts, query_seq_len, 1))
|
batch_sample_query_masks = batch_sample_spatial_masks.view((1, num_prompts, query_seq_len, 1))
|
||||||
|
|
||||||
for prompt_idx, embedding_range in enumerate(batch_sample_regions.ranges):
|
for prompt_idx, embedding_range in enumerate(batch_sample_regions.ranges):
|
||||||
batch_sample_query_scores = batch_sample_query_masks[0, prompt_idx, :, :].clone()
|
batch_sample_query_scores = batch_sample_query_masks[0, prompt_idx, :, :]
|
||||||
size = batch_sample_query_scores.sum() / batch_sample_query_scores.numel()
|
size = batch_sample_query_scores.sum() / batch_sample_query_scores.numel()
|
||||||
size = size.to(dtype=batch_sample_query_scores.dtype)
|
mask_weight = batch_sample_regions.mask_weights[prompt_idx]
|
||||||
batch_sample_query_mask = batch_sample_query_scores > 0.5
|
# size = size.to(dtype=batch_sample_query_scores.dtype)
|
||||||
batch_sample_query_scores[batch_sample_query_mask] = 1.0 * (1.0 - size)
|
# batch_sample_query_mask = batch_sample_query_scores > 0.5
|
||||||
batch_sample_query_scores[~batch_sample_query_mask] = 0.0
|
# batch_sample_query_scores[batch_sample_query_mask] = 1.0 * (1.0 - size)
|
||||||
attn_mask[batch_idx, :, embedding_range.start : embedding_range.end] = batch_sample_query_scores
|
# batch_sample_query_scores[~batch_sample_query_mask] = 0.0
|
||||||
|
attn_mask[batch_idx, :, embedding_range.start : embedding_range.end] = batch_sample_query_scores * (
|
||||||
|
mask_weight + self._size_weight * (1 - size)
|
||||||
|
)
|
||||||
|
|
||||||
return attn_mask
|
return attn_mask
|
||||||
|
|
||||||
@ -136,11 +140,15 @@ class RegionalPromptData:
|
|||||||
prompt_query_mask = batch_sample_query_masks[0, prompt_idx, :, 0] # Shape: (query_seq_len,)
|
prompt_query_mask = batch_sample_query_masks[0, prompt_idx, :, 0] # Shape: (query_seq_len,)
|
||||||
size = prompt_query_mask.sum() / prompt_query_mask.numel()
|
size = prompt_query_mask.sum() / prompt_query_mask.numel()
|
||||||
size = size.to(dtype=prompt_query_mask.dtype)
|
size = size.to(dtype=prompt_query_mask.dtype)
|
||||||
|
mask_weight = batch_sample_regions.mask_weights[prompt_idx]
|
||||||
# Multiply a (1, query_seq_len) mask by a (query_seq_len, 1) mask to get a (query_seq_len,
|
# Multiply a (1, query_seq_len) mask by a (query_seq_len, 1) mask to get a (query_seq_len,
|
||||||
# query_seq_len) mask.
|
# query_seq_len) mask.
|
||||||
# TODO(ryand): Is += really the best option here?
|
# TODO(ryand): Is += really the best option here? Maybe elementwise max is better?
|
||||||
attn_mask[batch_idx, :, :] += (
|
attn_mask[batch_idx, :, :] = torch.maximum(
|
||||||
prompt_query_mask.unsqueeze(0) * prompt_query_mask.unsqueeze(1) * (1 - size)
|
attn_mask[batch_idx, :, :],
|
||||||
|
prompt_query_mask.unsqueeze(0)
|
||||||
|
* prompt_query_mask.unsqueeze(1)
|
||||||
|
* (mask_weight + self._size_weight * (1 - size)),
|
||||||
)
|
)
|
||||||
|
|
||||||
# if attn_mask[batch_idx].max() < 0.01:
|
# if attn_mask[batch_idx].max() < 0.01:
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
import time
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Any, Callable, Optional, Union
|
from typing import Any, Callable, Optional, Union
|
||||||
|
|
||||||
@ -345,9 +346,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
r = TextConditioningRegions(
|
r = TextConditioningRegions(
|
||||||
masks=torch.ones((1, 1, h, w), dtype=torch.bool),
|
masks=torch.ones((1, 1, h, w), dtype=torch.bool),
|
||||||
ranges=[Range(start=0, end=c.embeds.shape[1])],
|
ranges=[Range(start=0, end=c.embeds.shape[1])],
|
||||||
positive_cross_attn_mask_scores=[0.0],
|
mask_weights=[0.0],
|
||||||
positive_self_attn_mask_scores=[0.0],
|
|
||||||
self_attn_adjustment_end_step_percents=[0.0],
|
|
||||||
)
|
)
|
||||||
regions.append(r)
|
regions.append(r)
|
||||||
|
|
||||||
@ -355,6 +354,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
regions=regions, device=x.device, dtype=x.dtype
|
regions=regions, device=x.device, dtype=x.dtype
|
||||||
)
|
)
|
||||||
cross_attention_kwargs["percent_through"] = percent_through
|
cross_attention_kwargs["percent_through"] = percent_through
|
||||||
|
time.sleep(1.0)
|
||||||
|
|
||||||
both_results = self.model_forward_callback(
|
both_results = self.model_forward_callback(
|
||||||
x_twice,
|
x_twice,
|
||||||
|
Loading…
Reference in New Issue
Block a user