mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add support for mask weights, and only mask the tokens associated with the prompts (not the entire 77-token embedding).
This commit is contained in:
parent
969982b789
commit
ff950bc5cd
@ -44,7 +44,7 @@ from .model import ClipField
|
||||
title="Prompt",
|
||||
tags=["prompt", "compel"],
|
||||
category="conditioning",
|
||||
version="1.1.0",
|
||||
version="1.2.0",
|
||||
)
|
||||
class CompelInvocation(BaseInvocation):
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
@ -62,9 +62,7 @@ class CompelInvocation(BaseInvocation):
|
||||
mask: Optional[MaskField] = InputField(
|
||||
default=None, description="A mask defining the region that this conditioning prompt applies to."
|
||||
)
|
||||
positive_cross_attn_mask_score: float = InputField(default=0.0, description="")
|
||||
positive_self_attn_mask_score: float = InputField(default=1.0, description="")
|
||||
self_attn_adjustment_end_step_percent: float = InputField(default=0.0, description="")
|
||||
mask_weight: float = InputField(default=1.0, description="")
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
@ -136,9 +134,7 @@ class CompelInvocation(BaseInvocation):
|
||||
conditioning=ConditioningField(
|
||||
conditioning_name=conditioning_name,
|
||||
mask=self.mask,
|
||||
positive_cross_attn_mask_score=self.positive_cross_attn_mask_score,
|
||||
positive_self_attn_mask_score=self.positive_self_attn_mask_score,
|
||||
self_attn_adjustment_end_step_percent=self.self_attn_adjustment_end_step_percent,
|
||||
mask_weight=self.mask_weight,
|
||||
)
|
||||
)
|
||||
|
||||
@ -254,7 +250,7 @@ class SDXLPromptInvocationBase:
|
||||
title="SDXL Prompt",
|
||||
tags=["sdxl", "compel", "prompt"],
|
||||
category="conditioning",
|
||||
version="1.1.0",
|
||||
version="1.2.0",
|
||||
)
|
||||
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
@ -281,9 +277,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
mask: Optional[MaskField] = InputField(
|
||||
default=None, description="A mask defining the region that this conditioning prompt applies to."
|
||||
)
|
||||
positive_cross_attn_mask_score: float = InputField(default=0.0, description="")
|
||||
positive_self_attn_mask_score: float = InputField(default=1.0, description="")
|
||||
self_attn_adjustment_end_step_percent: float = InputField(default=0.0, description="")
|
||||
mask_weight: float = InputField(default=1.0, description="")
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
@ -350,9 +344,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
conditioning=ConditioningField(
|
||||
conditioning_name=conditioning_name,
|
||||
mask=self.mask,
|
||||
positive_cross_attn_mask_score=self.positive_cross_attn_mask_score,
|
||||
positive_self_attn_mask_score=self.positive_self_attn_mask_score,
|
||||
self_attn_adjustment_end_step_percent=self.self_attn_adjustment_end_step_percent,
|
||||
mask_weight=self.mask_weight,
|
||||
)
|
||||
)
|
||||
|
||||
@ -403,7 +395,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
||||
|
||||
conditioning_name = context.conditioning.save(conditioning_data)
|
||||
|
||||
return ConditioningOutput(conditioning=ConditioningField(conditioning_name=conditioning_name))
|
||||
return ConditioningOutput(conditioning=ConditioningField(conditioning_name=conditioning_name, mask_weight=1.0))
|
||||
|
||||
|
||||
@invocation_output("clip_skip_output")
|
||||
|
@ -236,13 +236,7 @@ class ConditioningField(BaseModel):
|
||||
description="The bool mask associated with this conditioning tensor. Excluded regions should be set to False, "
|
||||
"included regions should be set to True.",
|
||||
)
|
||||
positive_cross_attn_mask_score: float = Field(
|
||||
default=0.0,
|
||||
# TODO(ryand): Add more details to this description
|
||||
description="The weight of this conditioning tensor's mask relative to overlapping masks.",
|
||||
)
|
||||
positive_self_attn_mask_score: float = Field(default=1.0, description="")
|
||||
self_attn_adjustment_end_step_percent: float = Field(default=0.0, description="")
|
||||
mask_weight: float = Field(description="")
|
||||
|
||||
|
||||
class MetadataField(RootModel):
|
||||
|
@ -466,9 +466,18 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
text_embedding.append(text_embedding_info.embeds)
|
||||
if not all_masks_are_none:
|
||||
# embedding_ranges.append(
|
||||
# Range(
|
||||
# start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1]
|
||||
# )
|
||||
# )
|
||||
# HACK(ryand): Contrary to its name, tokens_count_including_eos_bos does not seem to include eos and bos
|
||||
# in the count.
|
||||
embedding_ranges.append(
|
||||
Range(
|
||||
start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1]
|
||||
start=cur_text_embedding_len + 1,
|
||||
end=cur_text_embedding_len
|
||||
+ text_embedding_info.extra_conditioning.tokens_count_including_eos_bos,
|
||||
)
|
||||
)
|
||||
processed_masks.append(self._preprocess_regional_prompt_mask(mask, latent_height, latent_width))
|
||||
@ -483,11 +492,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
regions = TextConditioningRegions(
|
||||
masks=torch.cat(processed_masks, dim=1),
|
||||
ranges=embedding_ranges,
|
||||
positive_cross_attn_mask_scores=[x.positive_cross_attn_mask_score for x in conditioning_fields],
|
||||
positive_self_attn_mask_scores=[x.positive_self_attn_mask_score for x in conditioning_fields],
|
||||
self_attn_adjustment_end_step_percents=[
|
||||
x.self_attn_adjustment_end_step_percent for x in conditioning_fields
|
||||
],
|
||||
mask_weights=[x.mask_weight for x in conditioning_fields],
|
||||
)
|
||||
|
||||
if extra_conditioning is not None and len(text_conditionings) > 1:
|
||||
|
@ -74,9 +74,7 @@ class TextConditioningRegions:
|
||||
self,
|
||||
masks: torch.Tensor,
|
||||
ranges: list[Range],
|
||||
positive_cross_attn_mask_scores: list[float],
|
||||
positive_self_attn_mask_scores: list[float],
|
||||
self_attn_adjustment_end_step_percents: list[float],
|
||||
mask_weights: list[float],
|
||||
):
|
||||
# A binary mask indicating the regions of the image that the prompt should be applied to.
|
||||
# Shape: (1, num_prompts, height, width)
|
||||
@ -87,19 +85,9 @@ class TextConditioningRegions:
|
||||
# ranges[i] contains the embedding range for the i'th prompt / mask.
|
||||
self.ranges = ranges
|
||||
|
||||
# The following fields control the cross-attention and self-attention handling of region masks. They are indexed
|
||||
# by prompt / mask index.
|
||||
self.positive_cross_attn_mask_scores = positive_cross_attn_mask_scores
|
||||
self.positive_self_attn_mask_scores = positive_self_attn_mask_scores
|
||||
self.self_attn_adjustment_end_step_percents = self_attn_adjustment_end_step_percents
|
||||
self.mask_weights = mask_weights
|
||||
|
||||
assert (
|
||||
self.masks.shape[1]
|
||||
== len(self.ranges)
|
||||
== len(self.positive_cross_attn_mask_scores)
|
||||
== len(self.positive_self_attn_mask_scores)
|
||||
== len(self.self_attn_adjustment_end_step_percents)
|
||||
)
|
||||
assert self.masks.shape[1] == len(self.ranges) == len(self.mask_weights)
|
||||
|
||||
|
||||
class TextConditioningData:
|
||||
|
@ -33,6 +33,7 @@ class RegionalPromptData:
|
||||
regions, max_downscale_factor
|
||||
)
|
||||
self._negative_cross_attn_mask_score = 0.0
|
||||
self._size_weight = 1.0
|
||||
|
||||
def _prepare_spatial_masks(
|
||||
self, regions: list[TextConditioningRegions], max_downscale_factor: int = 8
|
||||
@ -96,13 +97,16 @@ class RegionalPromptData:
|
||||
batch_sample_query_masks = batch_sample_spatial_masks.view((1, num_prompts, query_seq_len, 1))
|
||||
|
||||
for prompt_idx, embedding_range in enumerate(batch_sample_regions.ranges):
|
||||
batch_sample_query_scores = batch_sample_query_masks[0, prompt_idx, :, :].clone()
|
||||
batch_sample_query_scores = batch_sample_query_masks[0, prompt_idx, :, :]
|
||||
size = batch_sample_query_scores.sum() / batch_sample_query_scores.numel()
|
||||
size = size.to(dtype=batch_sample_query_scores.dtype)
|
||||
batch_sample_query_mask = batch_sample_query_scores > 0.5
|
||||
batch_sample_query_scores[batch_sample_query_mask] = 1.0 * (1.0 - size)
|
||||
batch_sample_query_scores[~batch_sample_query_mask] = 0.0
|
||||
attn_mask[batch_idx, :, embedding_range.start : embedding_range.end] = batch_sample_query_scores
|
||||
mask_weight = batch_sample_regions.mask_weights[prompt_idx]
|
||||
# size = size.to(dtype=batch_sample_query_scores.dtype)
|
||||
# batch_sample_query_mask = batch_sample_query_scores > 0.5
|
||||
# batch_sample_query_scores[batch_sample_query_mask] = 1.0 * (1.0 - size)
|
||||
# batch_sample_query_scores[~batch_sample_query_mask] = 0.0
|
||||
attn_mask[batch_idx, :, embedding_range.start : embedding_range.end] = batch_sample_query_scores * (
|
||||
mask_weight + self._size_weight * (1 - size)
|
||||
)
|
||||
|
||||
return attn_mask
|
||||
|
||||
@ -136,11 +140,15 @@ class RegionalPromptData:
|
||||
prompt_query_mask = batch_sample_query_masks[0, prompt_idx, :, 0] # Shape: (query_seq_len,)
|
||||
size = prompt_query_mask.sum() / prompt_query_mask.numel()
|
||||
size = size.to(dtype=prompt_query_mask.dtype)
|
||||
mask_weight = batch_sample_regions.mask_weights[prompt_idx]
|
||||
# Multiply a (1, query_seq_len) mask by a (query_seq_len, 1) mask to get a (query_seq_len,
|
||||
# query_seq_len) mask.
|
||||
# TODO(ryand): Is += really the best option here?
|
||||
attn_mask[batch_idx, :, :] += (
|
||||
prompt_query_mask.unsqueeze(0) * prompt_query_mask.unsqueeze(1) * (1 - size)
|
||||
# TODO(ryand): Is += really the best option here? Maybe elementwise max is better?
|
||||
attn_mask[batch_idx, :, :] = torch.maximum(
|
||||
attn_mask[batch_idx, :, :],
|
||||
prompt_query_mask.unsqueeze(0)
|
||||
* prompt_query_mask.unsqueeze(1)
|
||||
* (mask_weight + self._size_weight * (1 - size)),
|
||||
)
|
||||
|
||||
# if attn_mask[batch_idx].max() < 0.01:
|
||||
|
@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
@ -345,9 +346,7 @@ class InvokeAIDiffuserComponent:
|
||||
r = TextConditioningRegions(
|
||||
masks=torch.ones((1, 1, h, w), dtype=torch.bool),
|
||||
ranges=[Range(start=0, end=c.embeds.shape[1])],
|
||||
positive_cross_attn_mask_scores=[0.0],
|
||||
positive_self_attn_mask_scores=[0.0],
|
||||
self_attn_adjustment_end_step_percents=[0.0],
|
||||
mask_weights=[0.0],
|
||||
)
|
||||
regions.append(r)
|
||||
|
||||
@ -355,6 +354,7 @@ class InvokeAIDiffuserComponent:
|
||||
regions=regions, device=x.device, dtype=x.dtype
|
||||
)
|
||||
cross_attention_kwargs["percent_through"] = percent_through
|
||||
time.sleep(1.0)
|
||||
|
||||
both_results = self.model_forward_callback(
|
||||
x_twice,
|
||||
|
Loading…
Reference in New Issue
Block a user