Apply black

This commit is contained in:
Martin Kristiansen
2023-07-27 10:54:01 -04:00
parent 2183dba5c5
commit 218b6d0546
148 changed files with 5486 additions and 6296 deletions

View File

@ -17,6 +17,7 @@ from torch import nn
import invokeai.backend.util.logging as logger
from ...util import torch_dtype
class CrossAttentionType(enum.Enum):
SELF = 1
TOKENS = 2
@ -55,9 +56,7 @@ class Context:
if name in self.self_cross_attention_module_identifiers:
assert False, f"name {name} cannot appear more than once"
self.self_cross_attention_module_identifiers.append(name)
for name, module in get_cross_attention_modules(
model, CrossAttentionType.TOKENS
):
for name, module in get_cross_attention_modules(model, CrossAttentionType.TOKENS):
if name in self.tokens_cross_attention_module_identifiers:
assert False, f"name {name} cannot appear more than once"
self.tokens_cross_attention_module_identifiers.append(name)
@ -68,9 +67,7 @@ class Context:
else:
self.tokens_cross_attention_action = Context.Action.SAVE
def request_apply_saved_attention_maps(
self, cross_attention_type: CrossAttentionType
):
def request_apply_saved_attention_maps(self, cross_attention_type: CrossAttentionType):
if cross_attention_type == CrossAttentionType.SELF:
self.self_cross_attention_action = Context.Action.APPLY
else:
@ -139,9 +136,7 @@ class Context:
saved_attention_dict = self.saved_cross_attention_maps[identifier]
if requested_dim is None:
if saved_attention_dict["dim"] is not None:
raise RuntimeError(
f"dim mismatch: expected dim=None, have {saved_attention_dict['dim']}"
)
raise RuntimeError(f"dim mismatch: expected dim=None, have {saved_attention_dict['dim']}")
return saved_attention_dict["slices"][0]
if saved_attention_dict["dim"] == requested_dim:
@ -154,21 +149,13 @@ class Context:
if saved_attention_dict["dim"] is None:
whole_saved_attention = saved_attention_dict["slices"][0]
if requested_dim == 0:
return whole_saved_attention[
requested_offset : requested_offset + slice_size
]
return whole_saved_attention[requested_offset : requested_offset + slice_size]
elif requested_dim == 1:
return whole_saved_attention[
:, requested_offset : requested_offset + slice_size
]
return whole_saved_attention[:, requested_offset : requested_offset + slice_size]
raise RuntimeError(
f"Cannot convert dim {saved_attention_dict['dim']} to requested dim {requested_dim}"
)
raise RuntimeError(f"Cannot convert dim {saved_attention_dict['dim']} to requested dim {requested_dim}")
def get_slicing_strategy(
self, identifier: str
) -> tuple[Optional[int], Optional[int]]:
def get_slicing_strategy(self, identifier: str) -> tuple[Optional[int], Optional[int]]:
saved_attention = self.saved_cross_attention_maps.get(identifier, None)
if saved_attention is None:
return None, None
@ -201,9 +188,7 @@ class InvokeAICrossAttentionMixin:
def set_attention_slice_wrangler(
self,
wrangler: Optional[
Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]
],
wrangler: Optional[Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]],
):
"""
Set custom attention calculator to be called when attention is calculated
@ -219,14 +204,10 @@ class InvokeAICrossAttentionMixin:
"""
self.attention_slice_wrangler = wrangler
def set_slicing_strategy_getter(
self, getter: Optional[Callable[[nn.Module], tuple[int, int]]]
):
def set_slicing_strategy_getter(self, getter: Optional[Callable[[nn.Module], tuple[int, int]]]):
self.slicing_strategy_getter = getter
def set_attention_slice_calculated_callback(
self, callback: Optional[Callable[[torch.Tensor], None]]
):
def set_attention_slice_calculated_callback(self, callback: Optional[Callable[[torch.Tensor], None]]):
self.attention_slice_calculated_callback = callback
def einsum_lowest_level(self, query, key, value, dim, offset, slice_size):
@ -247,45 +228,31 @@ class InvokeAICrossAttentionMixin:
)
# calculate attention slice by applying softmax over the scores for each latent pixel
default_attention_slice = attention_scores.softmax(
dim=-1, dtype=attention_scores.dtype
)
default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
attention_slice_wrangler = self.attention_slice_wrangler
if attention_slice_wrangler is not None:
attention_slice = attention_slice_wrangler(
self, default_attention_slice, dim, offset, slice_size
)
attention_slice = attention_slice_wrangler(self, default_attention_slice, dim, offset, slice_size)
else:
attention_slice = default_attention_slice
if self.attention_slice_calculated_callback is not None:
self.attention_slice_calculated_callback(
attention_slice, dim, offset, slice_size
)
self.attention_slice_calculated_callback(attention_slice, dim, offset, slice_size)
hidden_states = torch.bmm(attention_slice, value)
return hidden_states
def einsum_op_slice_dim0(self, q, k, v, slice_size):
r = torch.zeros(
q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype
)
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
for i in range(0, q.shape[0], slice_size):
end = i + slice_size
r[i:end] = self.einsum_lowest_level(
q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size
)
r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size)
return r
def einsum_op_slice_dim1(self, q, k, v, slice_size):
r = torch.zeros(
q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype
)
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
r[:, i:end] = self.einsum_lowest_level(
q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size
)
r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size)
return r
def einsum_op_mps_v1(self, q, k, v):
@ -353,6 +320,7 @@ def restore_default_cross_attention(
else:
remove_attention_function(model)
def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: Context):
"""
Inject attention parameters and functions into the passed in model to enable cross attention editing.
@ -372,7 +340,7 @@ def setup_cross_attention_control_attention_processors(unet: UNet2DConditionMode
indices = torch.arange(max_length, dtype=torch.long)
for name, a0, a1, b0, b1 in context.arguments.edit_opcodes:
if b0 < max_length:
if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0):
if name == "equal": # or (name == "replace" and a1 - a0 == b1 - b0):
# these tokens have not been edited
indices[b0:b1] = indices_target[a0:a1]
mask[b0:b1] = 1
@ -386,16 +354,14 @@ def setup_cross_attention_control_attention_processors(unet: UNet2DConditionMode
else:
# try to re-use an existing slice size
default_slice_size = 4
slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
slice_size = next(
(p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size
)
unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
def get_cross_attention_modules(
model, which: CrossAttentionType
) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
cross_attention_class: type = (
InvokeAIDiffusersCrossAttention
)
def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
cross_attention_class: type = InvokeAIDiffusersCrossAttention
which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2"
attention_module_tuples = [
(name, module)
@ -420,9 +386,7 @@ def get_cross_attention_modules(
def inject_attention_function(unet, context: Context):
# ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
def attention_slice_wrangler(
module, suggested_attention_slice: torch.Tensor, dim, offset, slice_size
):
def attention_slice_wrangler(module, suggested_attention_slice: torch.Tensor, dim, offset, slice_size):
# memory_usage = suggested_attention_slice.element_size() * suggested_attention_slice.nelement()
attention_slice = suggested_attention_slice
@ -430,9 +394,7 @@ def inject_attention_function(unet, context: Context):
if context.get_should_save_maps(module.identifier):
# print(module.identifier, "saving suggested_attention_slice of shape",
# suggested_attention_slice.shape, "dim", dim, "offset", offset)
slice_to_save = (
attention_slice.to("cpu") if dim is not None else attention_slice
)
slice_to_save = attention_slice.to("cpu") if dim is not None else attention_slice
context.save_slice(
module.identifier,
slice_to_save,
@ -442,31 +404,20 @@ def inject_attention_function(unet, context: Context):
)
elif context.get_should_apply_saved_maps(module.identifier):
# print(module.identifier, "applying saved attention slice for dim", dim, "offset", offset)
saved_attention_slice = context.get_slice(
module.identifier, dim, offset, slice_size
)
saved_attention_slice = context.get_slice(module.identifier, dim, offset, slice_size)
# slice may have been offloaded to CPU
saved_attention_slice = saved_attention_slice.to(
suggested_attention_slice.device
)
saved_attention_slice = saved_attention_slice.to(suggested_attention_slice.device)
if context.is_tokens_cross_attention(module.identifier):
index_map = context.cross_attention_index_map
remapped_saved_attention_slice = torch.index_select(
saved_attention_slice, -1, index_map
)
remapped_saved_attention_slice = torch.index_select(saved_attention_slice, -1, index_map)
this_attention_slice = suggested_attention_slice
mask = context.cross_attention_mask.to(
torch_dtype(suggested_attention_slice.device)
)
mask = context.cross_attention_mask.to(torch_dtype(suggested_attention_slice.device))
saved_mask = mask
this_mask = 1 - mask
attention_slice = (
remapped_saved_attention_slice * saved_mask
+ this_attention_slice * this_mask
)
attention_slice = remapped_saved_attention_slice * saved_mask + this_attention_slice * this_mask
else:
# just use everything
attention_slice = saved_attention_slice
@ -480,14 +431,10 @@ def inject_attention_function(unet, context: Context):
module.identifier = identifier
try:
module.set_attention_slice_wrangler(attention_slice_wrangler)
module.set_slicing_strategy_getter(
lambda module: context.get_slicing_strategy(identifier)
)
module.set_slicing_strategy_getter(lambda module: context.get_slicing_strategy(identifier))
except AttributeError as e:
if is_attribute_error_about(e, "set_attention_slice_wrangler"):
print(
f"TODO: implement set_attention_slice_wrangler for {type(module)}"
) # TODO
print(f"TODO: implement set_attention_slice_wrangler for {type(module)}") # TODO
else:
raise
@ -503,9 +450,7 @@ def remove_attention_function(unet):
module.set_slicing_strategy_getter(None)
except AttributeError as e:
if is_attribute_error_about(e, "set_attention_slice_wrangler"):
print(
f"TODO: implement set_attention_slice_wrangler for {type(module)}"
)
print(f"TODO: implement set_attention_slice_wrangler for {type(module)}")
else:
raise
@ -530,9 +475,7 @@ def get_mem_free_total(device):
return mem_free_total
class InvokeAIDiffusersCrossAttention(
diffusers.models.attention.Attention, InvokeAICrossAttentionMixin
):
class InvokeAIDiffusersCrossAttention(diffusers.models.attention.Attention, InvokeAICrossAttentionMixin):
def __init__(self, **kwargs):
super().__init__(**kwargs)
InvokeAICrossAttentionMixin.__init__(self)
@ -641,11 +584,7 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
# kwargs
swap_cross_attn_context: SwapCrossAttnContext = None,
):
attention_type = (
CrossAttentionType.SELF
if encoder_hidden_states is None
else CrossAttentionType.TOKENS
)
attention_type = CrossAttentionType.SELF if encoder_hidden_states is None else CrossAttentionType.TOKENS
# if cross-attention control is not in play, just call through to the base implementation.
if (
@ -654,9 +593,7 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
or not swap_cross_attn_context.wants_cross_attention_control(attention_type)
):
# print(f"SwapCrossAttnContext for {attention_type} not active - passing request to superclass")
return super().__call__(
attn, hidden_states, encoder_hidden_states, attention_mask
)
return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask)
# else:
# print(f"SwapCrossAttnContext for {attention_type} active")
@ -699,18 +636,10 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
query_slice = query[start_idx:end_idx]
original_key_slice = original_text_key[start_idx:end_idx]
modified_key_slice = modified_text_key[start_idx:end_idx]
attn_mask_slice = (
attention_mask[start_idx:end_idx]
if attention_mask is not None
else None
)
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
original_attn_slice = attn.get_attention_scores(
query_slice, original_key_slice, attn_mask_slice
)
modified_attn_slice = attn.get_attention_scores(
query_slice, modified_key_slice, attn_mask_slice
)
original_attn_slice = attn.get_attention_scores(query_slice, original_key_slice, attn_mask_slice)
modified_attn_slice = attn.get_attention_scores(query_slice, modified_key_slice, attn_mask_slice)
# because the prompt modifications may result in token sequences shifted forwards or backwards,
# the original attention probabilities must be remapped to account for token index changes in the
@ -722,9 +651,7 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
# only some tokens taken from the original attention probabilities. this is controlled by the mask.
mask = swap_cross_attn_context.mask
inverse_mask = 1 - mask
attn_slice = (
remapped_original_attn_slice * mask + modified_attn_slice * inverse_mask
)
attn_slice = remapped_original_attn_slice * mask + modified_attn_slice * inverse_mask
del remapped_original_attn_slice, modified_attn_slice
@ -744,6 +671,4 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
class SwapCrossAttnProcessor(SlicedSwapCrossAttnProcesser):
def __init__(self):
super(SwapCrossAttnProcessor, self).__init__(
slice_size=int(1e9)
) # massive slice size = don't slice
super(SwapCrossAttnProcessor, self).__init__(slice_size=int(1e9)) # massive slice size = don't slice

View File

@ -59,9 +59,7 @@ class AttentionMapSaver:
for key, maps in self.collated_maps.items():
# maps has shape [(H*W), N] for N tokens
# but we want [N, H, W]
this_scale_factor = math.sqrt(
maps.shape[0] / (latents_width * latents_height)
)
this_scale_factor = math.sqrt(maps.shape[0] / (latents_width * latents_height))
this_maps_height = int(float(latents_height) * this_scale_factor)
this_maps_width = int(float(latents_width) * this_scale_factor)
# and we need to do some dimension juggling
@ -72,9 +70,7 @@ class AttentionMapSaver:
# scale to output size if necessary
if this_scale_factor != 1:
maps = tv_resize(
maps, [latents_height, latents_width], InterpolationMode.BICUBIC
)
maps = tv_resize(maps, [latents_height, latents_width], InterpolationMode.BICUBIC)
# normalize
maps_min = torch.min(maps)
@ -83,9 +79,7 @@ class AttentionMapSaver:
maps_normalized = (maps - maps_min) / maps_range
# expand to (-0.05, 1.05) and clamp
maps_normalized_expanded = maps_normalized * 1.1 - 0.05
maps_normalized_expanded_clamped = torch.clamp(
maps_normalized_expanded, 0, 1
)
maps_normalized_expanded_clamped = torch.clamp(maps_normalized_expanded, 0, 1)
# merge together, producing a vertical stack
maps_stacked = torch.reshape(

View File

@ -31,6 +31,7 @@ ModelForwardCallback: TypeAlias = Union[
Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
]
@dataclass(frozen=True)
class PostprocessingSettings:
threshold: float
@ -81,14 +82,12 @@ class InvokeAIDiffuserComponent:
@contextmanager
def custom_attention_context(
cls,
unet: UNet2DConditionModel, # note: also may futz with the text encoder depending on requested LoRAs
unet: UNet2DConditionModel, # note: also may futz with the text encoder depending on requested LoRAs
extra_conditioning_info: Optional[ExtraConditioningInfo],
step_count: int
step_count: int,
):
old_attn_processors = None
if extra_conditioning_info and (
extra_conditioning_info.wants_cross_attention_control
):
if extra_conditioning_info and (extra_conditioning_info.wants_cross_attention_control):
old_attn_processors = unet.attn_processors
# Load lora conditions into the model
if extra_conditioning_info.wants_cross_attention_control:
@ -116,27 +115,15 @@ class InvokeAIDiffuserComponent:
return
saver.add_attention_maps(slice, key)
tokens_cross_attention_modules = get_cross_attention_modules(
self.model, CrossAttentionType.TOKENS
)
tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
for identifier, module in tokens_cross_attention_modules:
key = (
"down"
if identifier.startswith("down")
else "up"
if identifier.startswith("up")
else "mid"
)
key = "down" if identifier.startswith("down") else "up" if identifier.startswith("up") else "mid"
module.set_attention_slice_calculated_callback(
lambda slice, dim, offset, slice_size, key=key: callback(
slice, dim, offset, slice_size, key
)
lambda slice, dim, offset, slice_size, key=key: callback(slice, dim, offset, slice_size, key)
)
def remove_attention_map_saving(self):
tokens_cross_attention_modules = get_cross_attention_modules(
self.model, CrossAttentionType.TOKENS
)
tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
for _, module in tokens_cross_attention_modules:
module.set_attention_slice_calculated_callback(None)
@ -171,10 +158,8 @@ class InvokeAIDiffuserComponent:
context: Context = self.cross_attention_control_context
if self.cross_attention_control_context is not None:
percent_through = step_index / total_step_count
cross_attention_control_types_to_do = (
context.get_active_cross_attention_control_types_for_step(
percent_through
)
cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(
percent_through
)
wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
@ -182,7 +167,11 @@ class InvokeAIDiffuserComponent:
if wants_hybrid_conditioning:
unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning(
x, sigma, unconditioning, conditioning, **kwargs,
x,
sigma,
unconditioning,
conditioning,
**kwargs,
)
elif wants_cross_attention_control:
(
@ -201,7 +190,11 @@ class InvokeAIDiffuserComponent:
unconditioned_next_x,
conditioned_next_x,
) = self._apply_standard_conditioning_sequentially(
x, sigma, unconditioning, conditioning, **kwargs,
x,
sigma,
unconditioning,
conditioning,
**kwargs,
)
else:
@ -209,12 +202,18 @@ class InvokeAIDiffuserComponent:
unconditioned_next_x,
conditioned_next_x,
) = self._apply_standard_conditioning(
x, sigma, unconditioning, conditioning, **kwargs,
x,
sigma,
unconditioning,
conditioning,
**kwargs,
)
combined_next_x = self._combine(
# unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale
unconditioned_next_x, conditioned_next_x, guidance_scale
unconditioned_next_x,
conditioned_next_x,
guidance_scale,
)
return combined_next_x
@ -229,37 +228,47 @@ class InvokeAIDiffuserComponent:
) -> torch.Tensor:
if postprocessing_settings is not None:
percent_through = step_index / total_step_count
latents = self.apply_threshold(
postprocessing_settings, latents, percent_through
)
latents = self.apply_symmetry(
postprocessing_settings, latents, percent_through
)
latents = self.apply_threshold(postprocessing_settings, latents, percent_through)
latents = self.apply_symmetry(postprocessing_settings, latents, percent_through)
return latents
def _concat_conditionings_for_batch(self, unconditioning, conditioning):
def _pad_conditioning(cond, target_len, encoder_attention_mask):
conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
conditioning_attention_mask = torch.ones(
(cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype
)
if cond.shape[1] < max_len:
conditioning_attention_mask = torch.cat([
conditioning_attention_mask,
torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype),
], dim=1)
conditioning_attention_mask = torch.cat(
[
conditioning_attention_mask,
torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype),
],
dim=1,
)
cond = torch.cat([
cond,
torch.zeros((cond.shape[0], max_len - cond.shape[1], cond.shape[2]), device=cond.device, dtype=cond.dtype),
], dim=1)
cond = torch.cat(
[
cond,
torch.zeros(
(cond.shape[0], max_len - cond.shape[1], cond.shape[2]),
device=cond.device,
dtype=cond.dtype,
),
],
dim=1,
)
if encoder_attention_mask is None:
encoder_attention_mask = conditioning_attention_mask
else:
encoder_attention_mask = torch.cat([
encoder_attention_mask,
conditioning_attention_mask,
])
encoder_attention_mask = torch.cat(
[
encoder_attention_mask,
conditioning_attention_mask,
]
)
return cond, encoder_attention_mask
encoder_attention_mask = None
@ -277,11 +286,11 @@ class InvokeAIDiffuserComponent:
x_twice = torch.cat([x] * 2)
sigma_twice = torch.cat([sigma] * 2)
both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
unconditioning, conditioning
)
both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(unconditioning, conditioning)
both_results = self.model_forward_callback(
x_twice, sigma_twice, both_conditionings,
x_twice,
sigma_twice,
both_conditionings,
encoder_attention_mask=encoder_attention_mask,
**kwargs,
)
@ -312,13 +321,17 @@ class InvokeAIDiffuserComponent:
uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
unconditioned_next_x = self.model_forward_callback(
x, sigma, unconditioning,
x,
sigma,
unconditioning,
down_block_additional_residuals=uncond_down_block,
mid_block_additional_residual=uncond_mid_block,
**kwargs,
)
conditioned_next_x = self.model_forward_callback(
x, sigma, conditioning,
x,
sigma,
conditioning,
down_block_additional_residuals=cond_down_block,
mid_block_additional_residual=cond_mid_block,
**kwargs,
@ -335,13 +348,15 @@ class InvokeAIDiffuserComponent:
for k in conditioning:
if isinstance(conditioning[k], list):
both_conditionings[k] = [
torch.cat([unconditioning[k][i], conditioning[k][i]])
for i in range(len(conditioning[k]))
torch.cat([unconditioning[k][i], conditioning[k][i]]) for i in range(len(conditioning[k]))
]
else:
both_conditionings[k] = torch.cat([unconditioning[k], conditioning[k]])
unconditioned_next_x, conditioned_next_x = self.model_forward_callback(
x_twice, sigma_twice, both_conditionings, **kwargs,
x_twice,
sigma_twice,
both_conditionings,
**kwargs,
).chunk(2)
return unconditioned_next_x, conditioned_next_x
@ -388,9 +403,7 @@ class InvokeAIDiffuserComponent:
)
# do requested cross attention types for conditioning (positive prompt)
cross_attn_processor_context.cross_attention_types_to_do = (
cross_attention_control_types_to_do
)
cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do
conditioned_next_x = self.model_forward_callback(
x,
sigma,
@ -414,19 +427,14 @@ class InvokeAIDiffuserComponent:
latents: torch.Tensor,
percent_through: float,
) -> torch.Tensor:
if (
postprocessing_settings.threshold is None
or postprocessing_settings.threshold == 0.0
):
if postprocessing_settings.threshold is None or postprocessing_settings.threshold == 0.0:
return latents
threshold = postprocessing_settings.threshold
warmup = postprocessing_settings.warmup
if percent_through < warmup:
current_threshold = threshold + threshold * 5 * (
1 - (percent_through / warmup)
)
current_threshold = threshold + threshold * 5 * (1 - (percent_through / warmup))
else:
current_threshold = threshold
@ -440,18 +448,10 @@ class InvokeAIDiffuserComponent:
if self.debug_thresholding:
std, mean = [i.item() for i in torch.std_mean(latents)]
outside = torch.count_nonzero(
(latents < -current_threshold) | (latents > current_threshold)
)
logger.info(
f"Threshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})"
)
logger.debug(
f"min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}"
)
logger.debug(
f"{outside / latents.numel() * 100:.2f}% values outside threshold"
)
outside = torch.count_nonzero((latents < -current_threshold) | (latents > current_threshold))
logger.info(f"Threshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})")
logger.debug(f"min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}")
logger.debug(f"{outside / latents.numel() * 100:.2f}% values outside threshold")
if maxval < current_threshold and minval > -current_threshold:
return latents
@ -464,25 +464,17 @@ class InvokeAIDiffuserComponent:
latents = torch.clone(latents)
maxval = np.clip(maxval * scale, 1, current_threshold)
num_altered += torch.count_nonzero(latents > maxval)
latents[latents > maxval] = (
torch.rand_like(latents[latents > maxval]) * maxval
)
latents[latents > maxval] = torch.rand_like(latents[latents > maxval]) * maxval
if minval < -current_threshold:
latents = torch.clone(latents)
minval = np.clip(minval * scale, -current_threshold, -1)
num_altered += torch.count_nonzero(latents < minval)
latents[latents < minval] = (
torch.rand_like(latents[latents < minval]) * minval
)
latents[latents < minval] = torch.rand_like(latents[latents < minval]) * minval
if self.debug_thresholding:
logger.debug(
f"min, , max = {minval:.3f}, , {maxval:.3f}\t(scaled by {scale})"
)
logger.debug(
f"{num_altered / latents.numel() * 100:.2f}% values altered"
)
logger.debug(f"min, , max = {minval:.3f}, , {maxval:.3f}\t(scaled by {scale})")
logger.debug(f"{num_altered / latents.numel() * 100:.2f}% values altered")
return latents
@ -501,15 +493,11 @@ class InvokeAIDiffuserComponent:
# Check for out of bounds
h_symmetry_time_pct = postprocessing_settings.h_symmetry_time_pct
if h_symmetry_time_pct is not None and (
h_symmetry_time_pct <= 0.0 or h_symmetry_time_pct > 1.0
):
if h_symmetry_time_pct is not None and (h_symmetry_time_pct <= 0.0 or h_symmetry_time_pct > 1.0):
h_symmetry_time_pct = None
v_symmetry_time_pct = postprocessing_settings.v_symmetry_time_pct
if v_symmetry_time_pct is not None and (
v_symmetry_time_pct <= 0.0 or v_symmetry_time_pct > 1.0
):
if v_symmetry_time_pct is not None and (v_symmetry_time_pct <= 0.0 or v_symmetry_time_pct > 1.0):
v_symmetry_time_pct = None
dev = latents.device.type
@ -554,9 +542,7 @@ class InvokeAIDiffuserComponent:
def estimate_percent_through(self, step_index, sigma):
if step_index is not None and self.cross_attention_control_context is not None:
# percent_through will never reach 1.0 (but this is intended)
return float(step_index) / float(
self.cross_attention_control_context.step_count
)
return float(step_index) / float(self.cross_attention_control_context.step_count)
# find the best possible index of the current sigma in the sigma sequence
smaller_sigmas = torch.nonzero(self.model.sigmas <= sigma)
sigma_index = smaller_sigmas[-1].item() if smaller_sigmas.shape[0] > 0 else 0
@ -567,19 +553,13 @@ class InvokeAIDiffuserComponent:
# todo: make this work
@classmethod
def apply_conjunction(
cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale
):
def apply_conjunction(cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale):
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2) # aka sigmas
deltas = None
uncond_latents = None
weighted_cond_list = (
c_or_weighted_c_list
if type(c_or_weighted_c_list) is list
else [(c_or_weighted_c_list, 1)]
)
weighted_cond_list = c_or_weighted_c_list if type(c_or_weighted_c_list) is list else [(c_or_weighted_c_list, 1)]
# NOTE: the code below is ugly and could use a cleanup
conditionings = [uc] + [c for c, weight in weighted_cond_list]
@ -608,15 +588,11 @@ class InvokeAIDiffuserComponent:
deltas = torch.cat((deltas, latents_b - uncond_latents))
# merge the weighted deltas together into a single merged delta
per_delta_weights = torch.tensor(
weights[1:], dtype=deltas.dtype, device=deltas.device
)
per_delta_weights = torch.tensor(weights[1:], dtype=deltas.dtype, device=deltas.device)
normalize = False
if normalize:
per_delta_weights /= torch.sum(per_delta_weights)
reshaped_weights = per_delta_weights.reshape(
per_delta_weights.shape + (1, 1, 1)
)
reshaped_weights = per_delta_weights.reshape(per_delta_weights.shape + (1, 1, 1))
deltas_merged = torch.sum(deltas * reshaped_weights, dim=0, keepdim=True)
# old_return_value = super().forward(x, sigma, uncond, cond, cond_scale)