Save and display per-token attention maps (#1866)

* attention maps saving to /tmp

* tidy up diffusers branch backporting of cross attention refactoring

* base64-encoding the attention maps image for generationResult

* cleanup/refactor conditioning.py

* attention maps and tokens being sent to web UI

* attention maps: restrict count to actual token count and improve robustness

* add argument type hint to image_to_dataURL function

Co-authored-by: damian <git@damianstewart.com>
Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Damian Stewart committed on 2022-12-10 15:57:41 +01:00 (via GitHub)
parent 55132f6463
commit 786b8878d6
13 changed files with 636 additions and 346 deletions
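For orientation, the web-facing change boils down to two new fields on the socketio "generationResult" payload emitted in the first diff below. An illustrative sketch of the payload shape, using only the fields visible in this diff and placeholder values:

    example_generation_result = {
        "width": 512,                                  # placeholder values
        "height": 512,
        "boundingBox": None,
        "generationMode": "txt2img",
        "attentionMaps": "data:image/png;base64,...",  # new: stacked per-token attention maps image, or None
        "tokens": ["a</w>", "cat</w>", "on</w>"],      # new: CLIP tokens for the prompt, or None for blends
    }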


@@ -18,9 +18,11 @@ from PIL.Image import Image as ImageType
from uuid import uuid4
from threading import Event
from ldm.generate import Generate
from ldm.invoke.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash
from ldm.invoke.conditioning import get_tokens_for_prompt, get_prompt_structure
from ldm.invoke.pngwriter import PngWriter, retrieve_metadata
-from ldm.invoke.prompt_parser import split_weighted_subprompts
from ldm.invoke.prompt_parser import split_weighted_subprompts, Blend
from ldm.invoke.generator.inpaint import infill_methods
from backend.modules.parameters import parameters_to_command
@@ -39,7 +41,7 @@ if not os.path.isabs(args.outdir):
class InvokeAIWebServer:
-def __init__(self, generate, gfpgan, codeformer, esrgan) -> None:
def __init__(self, generate: Generate, gfpgan, codeformer, esrgan) -> None:
self.host = args.host
self.port = args.port
@@ -905,16 +907,13 @@ class InvokeAIWebServer:
},
)
if generation_parameters["progress_latents"]:
image = self.generate.sample_to_lowres_estimated_image(sample)
(width, height) = image.size
width *= 8
height *= 8
-buffered = io.BytesIO()
-image.save(buffered, format="PNG")
-img_base64 = "data:image/png;base64," + base64.b64encode(
-buffered.getvalue()
-).decode("UTF-8")
img_base64 = image_to_dataURL(image)
self.socketio.emit(
"intermediateResult",
{
@@ -932,7 +931,7 @@ class InvokeAIWebServer:
self.socketio.emit("progressUpdate", progress.to_formatted_dict())
eventlet.sleep(0)
-def image_done(image, seed, first_seed):
def image_done(image, seed, first_seed, attention_maps_image=None):
if self.canceled.is_set():
raise CanceledException
@@ -1094,6 +1093,12 @@ class InvokeAIWebServer:
self.socketio.emit("progressUpdate", progress.to_formatted_dict())
eventlet.sleep(0)
parsed_prompt, _ = get_prompt_structure(generation_parameters["prompt"])
tokens = None if type(parsed_prompt) is Blend else \
get_tokens_for_prompt(self.generate.model, parsed_prompt)
attention_maps_image_base64_url = None if attention_maps_image is None \
else image_to_dataURL(attention_maps_image)
self.socketio.emit(
"generationResult",
{
@@ -1106,6 +1111,8 @@ class InvokeAIWebServer:
"height": height,
"boundingBox": original_bounding_box,
"generationMode": generation_parameters["generation_mode"],
"attentionMaps": attention_maps_image_base64_url,
"tokens": tokens,
},
)
eventlet.sleep(0)
@@ -1117,7 +1124,7 @@ class InvokeAIWebServer:
self.generate.prompt2image(
**generation_parameters,
step_callback=image_progress,
-image_callback=image_done,
image_callback=image_done
)
except KeyboardInterrupt:
@@ -1564,6 +1571,19 @@ def dataURL_to_image(dataURL: str) -> ImageType:
)
return image
"""
Converts an image into a base64 image dataURL.
"""
def image_to_dataURL(image: ImageType) -> str:
buffered = io.BytesIO()
image.save(buffered, format="PNG")
image_base64 = "data:image/png;base64," + base64.b64encode(
buffered.getvalue()
).decode("UTF-8")
return image_base64
""" """
Converts a base64 image dataURL into bytes. Converts a base64 image dataURL into bytes.
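For reference, a minimal usage sketch of the image_to_dataURL helper added above (the 8x8 test image is illustrative):

    import base64, io
    from PIL import Image

    def image_to_dataURL(image) -> str:
        # same logic as the helper added above
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        return "data:image/png;base64," + base64.b64encode(buffered.getvalue()).decode("UTF-8")

    img = Image.new("L", (8, 8), color=128)   # tiny illustrative grayscale image
    data_url = image_to_dataURL(img)
    assert data_url.startswith("data:image/png;base64,")
    # the resulting string can be dropped straight into an <img src="..."> on the web UI side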


@@ -20,6 +20,8 @@ import cv2
import skimage
from omegaconf import OmegaConf
import ldm.invoke.conditioning
from ldm.invoke.generator.base import downsampling
from PIL import Image, ImageOps
from torch import nn
@@ -455,7 +457,7 @@ class Generate:
try:
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(
prompt, model =self.model,
-skip_normalize=skip_normalize,
skip_normalize_legacy_blend=skip_normalize,
log_tokens =self.log_tokenization
)
@@ -607,8 +609,8 @@ class Generate:
# todo: cross-attention control
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(
prompt, model =self.model,
-skip_normalize=opt.skip_normalize,
skip_normalize_legacy_blend=opt.skip_normalize,
-log_tokens =opt.log_tokenization
log_tokens =ldm.invoke.conditioning.log_tokenization
)
if tool in ('gfpgan','codeformer','upscale'):


@@ -8,6 +8,7 @@ import time
import traceback
import yaml
from ldm.generate import Generate
from ldm.invoke.globals import Globals
from ldm.invoke.prompt_parser import PromptParser
from ldm.invoke.readline import get_completer, Completer
@@ -281,7 +282,7 @@ def main_loop(gen, opt):
prefix = file_writer.unique_prefix()
step_callback = make_step_callback(gen, opt, prefix) if opt.save_intermediates > 0 else None
-def image_writer(image, seed, upscaled=False, first_seed=None, use_prefix=None, prompt_in=None):
def image_writer(image, seed, upscaled=False, first_seed=None, use_prefix=None, prompt_in=None, attention_maps_image=None):
# note the seed is the seed of the current image
# the first_seed is the original seed that noise is added to
# when the -v switch is used to generate variations
@@ -789,7 +790,7 @@ def get_next_command(infile=None) -> str: # command string
print(f'#{command}')
return command
-def invoke_ai_web_server_loop(gen, gfpgan, codeformer, esrgan):
def invoke_ai_web_server_loop(gen: Generate, gfpgan, codeformer, esrgan):
print('\n* --web was specified, starting web server...')
from backend.invoke_ai_web_server import InvokeAIWebServer
# Change working directory to the stable-diffusion directory


@@ -7,20 +7,46 @@ get_uc_and_c_and_ec() get the conditioned and unconditioned latent, an
'''
import re
-from difflib import SequenceMatcher
from typing import Union
import torch
from .prompt_parser import PromptParser, Blend, FlattenedPrompt, \
-CrossAttentionControlledFragment, CrossAttentionControlSubstitute, Fragment, log_tokenization
CrossAttentionControlledFragment, CrossAttentionControlSubstitute, Fragment
from ..models.diffusion import cross_attention_control
from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder
-def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_normalize=False):
def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False):
prompt, negative_prompt = get_prompt_structure(prompt_string,
skip_normalize_legacy_blend=skip_normalize_legacy_blend)
conditioning = _get_conditioning_for_prompt(prompt, negative_prompt, model, log_tokens)
return conditioning
def get_prompt_structure(prompt_string, skip_normalize_legacy_blend: bool = False) -> (
Union[FlattenedPrompt, Blend], FlattenedPrompt):
"""
parse the passed-in prompt string and return tuple (positive_prompt, negative_prompt)
"""
prompt, negative_prompt = _parse_prompt_string(prompt_string,
skip_normalize_legacy_blend=skip_normalize_legacy_blend)
return prompt, negative_prompt
def get_tokens_for_prompt(model, parsed_prompt: FlattenedPrompt) -> [str]:
text_fragments = [x.text if type(x) is Fragment else
(" ".join([f.text for f in x.original]) if type(x) is CrossAttentionControlSubstitute else
str(x))
for x in parsed_prompt.children]
text = " ".join(text_fragments)
tokens = model.cond_stage_model.tokenizer.tokenize(text)
return tokens
def _parse_prompt_string(prompt_string_uncleaned, skip_normalize_legacy_blend=False) -> Union[FlattenedPrompt, Blend]:
# Extract Unconditioned Words From Prompt
unconditioned_words = ''
unconditional_regex = r'\[(.*?)\]'
@@ -39,7 +65,7 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
pp = PromptParser()
parsed_prompt: Union[FlattenedPrompt, Blend] = None
-legacy_blend: Blend = pp.parse_legacy_blend(prompt_string_cleaned)
legacy_blend: Blend = pp.parse_legacy_blend(prompt_string_cleaned, skip_normalize_legacy_blend)
if legacy_blend is not None:
parsed_prompt = legacy_blend
else:
return parsed_prompt, parsed_negative_prompt
def _get_conditioning_for_prompt(parsed_prompt: Union[Blend, FlattenedPrompt], parsed_negative_prompt: FlattenedPrompt,
model, log_tokens=False) \
-> tuple[torch.Tensor, torch.Tensor, InvokeAIDiffuserComponent.ExtraConditioningInfo]:
"""
Process prompt structure and tokens, and return (conditioning, unconditioning, extra_conditioning_info)
"""
if log_tokens:
print(f">> Parsed prompt to {parsed_prompt}")
print(f">> Parsed negative prompt to {parsed_negative_prompt}")
conditioning = None
cac_args: cross_attention_control.Arguments = None
if type(parsed_prompt) is Blend:
-blend: Blend = parsed_prompt
-embeddings_to_blend = None
-for i,flattened_prompt in enumerate(blend.prompts):
-this_embedding, _ = build_embeddings_and_tokens_for_flattened_prompt(model,
-flattened_prompt,
-log_tokens=log_tokens,
-log_display_label=f"(blend part {i+1}, weight={blend.weights[i]})" )
-embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat(
-(embeddings_to_blend, this_embedding))
-conditioning = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0),
-blend.weights,
-normalize=blend.normalize_weights)
-else:
-flattened_prompt: FlattenedPrompt = parsed_prompt
-wants_cross_attention_control = type(flattened_prompt) is not Blend \
-and any([issubclass(type(x), CrossAttentionControlledFragment) for x in flattened_prompt.children])
-if wants_cross_attention_control:
conditioning = _get_conditioning_for_blend(model, parsed_prompt, log_tokens)
elif type(parsed_prompt) is FlattenedPrompt:
if parsed_prompt.wants_cross_attention_control:
conditioning, cac_args = _get_conditioning_for_cross_attention_control(model, parsed_prompt, log_tokens)
else:
conditioning, _ = _get_embeddings_and_tokens_for_prompt(model,
parsed_prompt,
log_tokens=log_tokens,
log_display_label="(prompt)")
else:
raise ValueError(f"parsed_prompt is '{type(parsed_prompt)}' which is not a supported prompt type")
unconditioning, _ = _get_embeddings_and_tokens_for_prompt(model,
parsed_negative_prompt,
log_tokens=log_tokens,
log_display_label="(unconditioning)")
if isinstance(conditioning, dict):
# hybrid conditioning is in play
unconditioning, conditioning = _flatten_hybrid_conditioning(unconditioning, conditioning)
if cac_args is not None:
print(
">> Hybrid conditioning cannot currently be combined with cross attention control. Cross attention control will be ignored.")
cac_args = None
eos_token_index = 1
if type(parsed_prompt) is not Blend:
tokens = get_tokens_for_prompt(model, parsed_prompt)
eos_token_index = len(tokens)+1
return (
unconditioning, conditioning, InvokeAIDiffuserComponent.ExtraConditioningInfo(
tokens_count_including_eos_bos=eos_token_index + 1,
cross_attention_control_args=cac_args
)
)
def _get_conditioning_for_cross_attention_control(model, prompt: FlattenedPrompt, log_tokens: bool = True):
original_prompt = FlattenedPrompt()
edited_prompt = FlattenedPrompt()
# for name, a0, a1, b0, b1 in edit_opcodes: only name == 'equal' is currently parsed
@@ -80,17 +137,18 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
edit_options = []
edit_opcodes = []
# beginning of sequence
edit_opcodes.append(
('equal', original_token_count, original_token_count + 1, edited_token_count, edited_token_count + 1))
edit_options.append(None)
original_token_count += 1
edited_token_count += 1
-for fragment in flattened_prompt.children:
for fragment in prompt.children:
if type(fragment) is CrossAttentionControlSubstitute:
original_prompt.append(fragment.original)
edited_prompt.append(fragment.edited)
-to_replace_token_count = get_tokens_length(model, fragment.original)
-replacement_token_count = get_tokens_length(model, fragment.edited)
to_replace_token_count = _get_tokens_length(model, fragment.original)
replacement_token_count = _get_tokens_length(model, fragment.edited)
edit_opcodes.append(('replace',
original_token_count, original_token_count + to_replace_token_count,
edited_token_count, edited_token_count + replacement_token_count
@@ -98,25 +156,26 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
original_token_count += to_replace_token_count
edited_token_count += replacement_token_count
edit_options.append(fragment.options)
# elif type(fragment) is CrossAttentionControlAppend:
# edited_prompt.append(fragment.fragment)
else:
# regular fragment
original_prompt.append(fragment)
edited_prompt.append(fragment)
-count = get_tokens_length(model, [fragment])
count = _get_tokens_length(model, [fragment])
edit_opcodes.append(('equal', original_token_count, original_token_count + count, edited_token_count,
edited_token_count + count))
edit_options.append(None)
original_token_count += count
edited_token_count += count
# end of sequence
edit_opcodes.append(
('equal', original_token_count, original_token_count + 1, edited_token_count, edited_token_count + 1))
edit_options.append(None)
original_token_count += 1
edited_token_count += 1
-original_embeddings, original_tokens = build_embeddings_and_tokens_for_flattened_prompt(model,
original_embeddings, original_tokens = _get_embeddings_and_tokens_for_prompt(model,
original_prompt,
log_tokens=log_tokens,
log_display_label="(.swap originals)")
@@ -126,50 +185,38 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
# 'cat' edit is not, the 'pizza' feature vector will nevertheless be affected by the introduction of the extra
# token 'smiling' in the inactive 'cat' edit.
# todo: build multiple edited_embeddings, one for each edit, and pass just the edited fragments through to the CrossAttentionControl functions
-edited_embeddings, edited_tokens = build_embeddings_and_tokens_for_flattened_prompt(model,
edited_embeddings, edited_tokens = _get_embeddings_and_tokens_for_prompt(model,
edited_prompt,
log_tokens=log_tokens,
log_display_label="(.swap replacements)")
conditioning = original_embeddings
edited_conditioning = edited_embeddings
# print('>> got edit_opcodes', edit_opcodes, 'options', edit_options)
cac_args = cross_attention_control.Arguments(
edited_conditioning=edited_conditioning,
edit_opcodes=edit_opcodes,
edit_options=edit_options
)
-else:
-conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model,
-flattened_prompt,
-log_tokens=log_tokens,
-log_display_label="(prompt)")
-unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model,
-parsed_negative_prompt,
-log_tokens=log_tokens,
-log_display_label="(unconditioning)")
-if isinstance(conditioning, dict):
-# hybrid conditioning is in play
-unconditioning, conditioning = flatten_hybrid_conditioning(unconditioning, conditioning)
-if cac_args is not None:
-print(">> Hybrid conditioning cannot currently be combined with cross attention control. Cross attention control will be ignored.")
-cac_args = None
-return (
-unconditioning, conditioning, InvokeAIDiffuserComponent.ExtraConditioningInfo(
-cross_attention_control_args=cac_args
-)
-)
-def build_token_edit_opcodes(original_tokens, edited_tokens):
-original_tokens = original_tokens.cpu().numpy()[0]
-edited_tokens = edited_tokens.cpu().numpy()[0]
-return SequenceMatcher(None, original_tokens, edited_tokens).get_opcodes()
-def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt, log_tokens: bool=False, log_display_label: str=None):
return conditioning, cac_args
def _get_conditioning_for_blend(model, blend: Blend, log_tokens: bool = False):
embeddings_to_blend = None
for i, flattened_prompt in enumerate(blend.prompts):
this_embedding, _ = _get_embeddings_and_tokens_for_prompt(model,
flattened_prompt,
log_tokens=log_tokens,
log_display_label=f"(blend part {i + 1}, weight={blend.weights[i]})")
embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat(
(embeddings_to_blend, this_embedding))
conditioning = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0),
blend.weights,
normalize=blend.normalize_weights)
return conditioning
def _get_embeddings_and_tokens_for_prompt(model, flattened_prompt: FlattenedPrompt, log_tokens: bool = False,
log_display_label: str = None):
if type(flattened_prompt) is not FlattenedPrompt:
raise Exception(f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead")
fragments = [x.text for x in flattened_prompt.children]
@@ -181,12 +228,14 @@ def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: Fl
return embeddings, tokens
-def get_tokens_length(model, fragments: list[Fragment]):
def _get_tokens_length(model, fragments: list[Fragment]):
fragment_texts = [x.text for x in fragments]
tokens = model.cond_stage_model.get_tokens(fragment_texts, include_start_and_end_markers=False)
return sum([len(x) for x in tokens])
-def flatten_hybrid_conditioning(uncond, cond):
def _flatten_hybrid_conditioning(uncond, cond):
'''
This handles the choice between a conditional conditioning
that is a tensor (used by cross attention) vs one that has additional
@@ -206,3 +255,28 @@ def flatten_hybrid_conditioning(uncond, cond):
return uncond, cond_flattened
def log_tokenization(text, model, display_label=None):
""" shows how the prompt is tokenized
# usually tokens have '</w>' to indicate end-of-word,
# but for readability it has been replaced with ' '
"""
tokens = model.cond_stage_model.tokenizer.tokenize(text)
tokenized = ""
discarded = ""
usedTokens = 0
totalTokens = len(tokens)
for i in range(0, totalTokens):
token = tokens[i].replace('</w>', ' ')
# alternate color
s = (usedTokens % 6) + 1
if i < model.cond_stage_model.max_length:
tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
usedTokens += 1
else: # over max token length
discarded = discarded + f"\x1b[0;3{s};40m{token}"
print(f"\n>> Tokens {display_label or ''} ({usedTokens}):\n{tokenized}\x1b[0m")
if discarded != "":
print(
f">> Tokens Discarded ({totalTokens - usedTokens}):\n{discarded}\x1b[0m"
)
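To summarize the refactored conditioning API above, a hedged usage sketch (assumes `model` is a loaded Stable Diffusion model exposing a `cond_stage_model`; the wrapper function name and prompt text are illustrative):

    from ldm.invoke.conditioning import (
        get_prompt_structure, get_tokens_for_prompt, get_uc_and_c_and_ec)
    from ldm.invoke.prompt_parser import Blend

    def describe_and_condition(model, prompt_string: str):
        # parse only: returns (positive_prompt, negative_prompt) structures
        positive, negative = get_prompt_structure(prompt_string)
        # token list for UI display; Blends have no single token list
        tokens = None if type(positive) is Blend else get_tokens_for_prompt(model, positive)
        # full conditioning for sampling, plus the token count now carried for attention map saving
        uc, c, extra_info = get_uc_and_c_and_ec(prompt_string, model=model, log_tokens=True)
        return tokens, extra_info.tokens_count_including_eos_bos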


@@ -14,6 +14,7 @@ import cv2 as cv
from einops import rearrange, repeat
from pytorch_lightning import seed_everything
from ldm.invoke.devices import choose_autocast
from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
from ldm.util import rand_perlin_2d
downsampling = 8
@@ -51,9 +52,12 @@ class Generator():
def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None,
image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
safety_checker:dict=None,
attention_maps_callback = None,
**kwargs):
scope = choose_autocast(self.precision)
self.safety_checker = safety_checker
attention_maps_images = []
attention_maps_callback = lambda saver: attention_maps_images.append(saver.get_stacked_maps_image())
make_image = self.get_make_image(
prompt,
sampler = sampler,
@@ -63,6 +67,7 @@ class Generator():
step_callback = step_callback,
threshold = threshold,
perlin = perlin,
attention_maps_callback = attention_maps_callback,
**kwargs
)
results = []
@@ -98,7 +103,7 @@ class Generator():
results.append([image, seed])
if image_callback is not None:
-image_callback(image, seed, first_seed=first_seed)
image_callback(image, seed, first_seed=first_seed, attention_maps_image=attention_maps_images[-1])
seed = self.new_seed()


@@ -14,7 +14,9 @@ class Txt2Img(Generator):
@torch.no_grad()
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
-conditioning,width,height,step_callback=None,threshold=0.0,perlin=0.0,**kwargs):
conditioning,width,height,step_callback=None,threshold=0.0,perlin=0.0,
attention_maps_callback=None,
**kwargs):
"""
Returns a function returning an image derived from the prompt and the initial image
Return value depends on the seed at the time you call it
@@ -49,6 +51,7 @@ class Txt2Img(Generator):
eta = ddim_eta,
img_callback = step_callback,
threshold = threshold,
attention_maps_callback = attention_maps_callback,
)
if self.free_gpu_mem:


@@ -69,6 +69,12 @@ class FlattenedPrompt():
return len(self.children) == 0 or \
(len(self.children) == 1 and len(self.children[0].text) == 0)
@property
def wants_cross_attention_control(self):
return any(
[issubclass(type(x), CrossAttentionControlledFragment) for x in self.children]
)
def __repr__(self):
return f"FlattenedPrompt:{self.children}"
def __eq__(self, other):
@@ -240,6 +246,12 @@ class Blend():
self.weights = weights
self.normalize_weights = normalize_weights
@property
def wants_cross_attention_control(self):
# blends cannot cross-attention control
return False
def __repr__(self):
return f"Blend:{self.prompts} | weights {' ' if self.normalize_weights else '(non-normalized) '}{self.weights}"
def __eq__(self, other):
@@ -277,8 +289,8 @@ class PromptParser():
return self.flatten(root[0])
-def parse_legacy_blend(self, text: str) -> Optional[Blend]:
def parse_legacy_blend(self, text: str, skip_normalize: bool) -> Optional[Blend]:
-weighted_subprompts = split_weighted_subprompts(text, skip_normalize=False)
weighted_subprompts = split_weighted_subprompts(text, skip_normalize=skip_normalize)
if len(weighted_subprompts) <= 1:
return None
strings = [x[0] for x in weighted_subprompts]
@@ -287,7 +299,7 @@ class PromptParser():
parsed_conjunctions = [self.parse_conjunction(x) for x in strings]
flattened_prompts = [x.prompts[0] for x in parsed_conjunctions]
-return Blend(prompts=flattened_prompts, weights=weights, normalize_weights=True)
return Blend(prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize)
def flatten(self, root: Conjunction, verbose = False) -> Conjunction:
@@ -641,27 +653,3 @@ def split_weighted_subprompts(text, skip_normalize=False)->list:
return [(x[0], equal_weight) for x in parsed_prompts]
return [(x[0], x[1] / weight_sum) for x in parsed_prompts]
-# shows how the prompt is tokenized
-# usually tokens have '</w>' to indicate end-of-word,
-# but for readability it has been replaced with ' '
-def log_tokenization(text, model, display_label=None):
-tokens = model.cond_stage_model.tokenizer.tokenize(text)
-tokenized = ""
-discarded = ""
-usedTokens = 0
-totalTokens = len(tokens)
-for i in range(0, totalTokens):
-token = tokens[i].replace('</w>', ' ')
-# alternate color
-s = (usedTokens % 6) + 1
-if i < model.cond_stage_model.max_length:
-tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
-usedTokens += 1
-else: # over max token length
-discarded = discarded + f"\x1b[0;3{s};40m{token}"
-print(f"\n>> Tokens {display_label or ''} ({usedTokens}):\n{tokenized}\x1b[0m")
-if discarded != "":
-print(
-f">> Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m"
-)
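To illustrate the skip_normalize plumbing added to parse_legacy_blend above, a sketch using the legacy colon-weight blend syntax (the prompt text and 2:1 weights are illustrative):

    from ldm.invoke.prompt_parser import PromptParser

    pp = PromptParser()
    text = "a sunny meadow:2 a snowy mountain:1"   # legacy weighted-subprompt syntax

    blend_normalized = pp.parse_legacy_blend(text, skip_normalize=False)
    # weights normalized to sum to 1 (roughly [0.667, 0.333]); Blend.normalize_weights is True

    blend_raw = pp.parse_legacy_blend(text, skip_normalize=True)
    # weights kept as written ([2.0, 1.0]); Blend.normalize_weights is False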


@@ -53,7 +53,6 @@ COMMANDS = (
'--codeformer_fidelity','-cf',
'--upscale','-U',
'-save_orig','--save_original',
-'--skip_normalize','-x',
'--log_tokenization','-t',
'--hires_fix',
'--inpaint_replace','-r',


@@ -1,12 +1,14 @@
import enum
-from typing import Optional
import math
from typing import Optional, Callable
import psutil
import torch
from torch import nn
# adapted from bloc97's CrossAttentionControl colab
# https://github.com/bloc97/CrossAttentionControl
class Arguments:
def __init__(self, edited_conditioning: torch.Tensor, edit_opcodes: list[tuple], edit_options: dict):
"""
@@ -63,9 +65,13 @@ class Context:
self.clear_requests(cleanup=True)
def register_cross_attention_modules(self, model):
-for name,module in get_attention_modules(model, CrossAttentionType.SELF):
for name,module in get_cross_attention_modules(model, CrossAttentionType.SELF):
if name in self.self_cross_attention_module_identifiers:
assert False, f"name {name} cannot appear more than once"
self.self_cross_attention_module_identifiers.append(name)
-for name,module in get_attention_modules(model, CrossAttentionType.TOKENS):
for name,module in get_cross_attention_modules(model, CrossAttentionType.TOKENS):
if name in self.tokens_cross_attention_module_identifiers:
assert False, f"name {name} cannot appear more than once"
self.tokens_cross_attention_module_identifiers.append(name)
def request_save_attention_maps(self, cross_attention_type: CrossAttentionType):
@@ -166,6 +172,135 @@ class Context:
map_dict[offset] = slice.to('cpu')
class InvokeAICrossAttentionMixin:
"""
Enable InvokeAI-flavoured CrossAttention calculation, which does aggressive low-memory slicing and calls
through both to an attention_slice_wrangler and a slicing_strategy_getter for custom attention map wrangling
and dynamic slicing strategy selection.
"""
def __init__(self):
self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
self.attention_slice_wrangler = None
self.slicing_strategy_getter = None
self.attention_slice_calculated_callback = None
def set_attention_slice_wrangler(self, wrangler: Optional[Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]]):
'''
Set custom attention calculator to be called when attention is calculated
:param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size),
which returns either the suggested_attention_slice or an adjusted equivalent.
`module` is the current CrossAttention module for which the callback is being invoked.
`suggested_attention_slice` is the default-calculated attention slice
`dim` is -1 if the attention map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length.
Pass None to use the default attention calculation.
:return:
'''
self.attention_slice_wrangler = wrangler
def set_slicing_strategy_getter(self, getter: Optional[Callable[[nn.Module], tuple[int,int]]]):
self.slicing_strategy_getter = getter
def set_attention_slice_calculated_callback(self, callback: Optional[Callable[[torch.Tensor], None]]):
self.attention_slice_calculated_callback = callback
def einsum_lowest_level(self, query, key, value, dim, offset, slice_size):
# calculate attention scores
#attention_scores = torch.einsum('b i d, b j d -> b i j', q, k)
attention_scores = torch.baddbmm(
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
query,
key.transpose(-1, -2),
beta=0,
alpha=self.scale,
)
# calculate attention slice by taking the best scores for each latent pixel
default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
attention_slice_wrangler = self.attention_slice_wrangler
if attention_slice_wrangler is not None:
attention_slice = attention_slice_wrangler(self, default_attention_slice, dim, offset, slice_size)
else:
attention_slice = default_attention_slice
if self.attention_slice_calculated_callback is not None:
self.attention_slice_calculated_callback(attention_slice, dim, offset, slice_size)
hidden_states = torch.bmm(attention_slice, value)
return hidden_states
def einsum_op_slice_dim0(self, q, k, v, slice_size):
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
for i in range(0, q.shape[0], slice_size):
end = i + slice_size
r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size)
return r
def einsum_op_slice_dim1(self, q, k, v, slice_size):
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size)
return r
def einsum_op_mps_v1(self, q, k, v):
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
return self.einsum_lowest_level(q, k, v, None, None, None)
else:
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
return self.einsum_op_slice_dim1(q, k, v, slice_size)
def einsum_op_mps_v2(self, q, k, v):
if self.mem_total_gb > 8 and q.shape[1] <= 4096:
return self.einsum_lowest_level(q, k, v, None, None, None)
else:
return self.einsum_op_slice_dim0(q, k, v, 1)
def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb):
size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
if size_mb <= max_tensor_mb:
return self.einsum_lowest_level(q, k, v, None, None, None)
div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
if div <= q.shape[0]:
return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div)
return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1))
def einsum_op_cuda(self, q, k, v):
# check if we already have a slicing strategy (this should only happen during cross-attention controlled generation)
slicing_strategy_getter = self.slicing_strategy_getter
if slicing_strategy_getter is not None:
(dim, slice_size) = slicing_strategy_getter(self)
if dim is not None:
# print("using saved slicing strategy with dim", dim, "slice size", slice_size)
if dim == 0:
return self.einsum_op_slice_dim0(q, k, v, slice_size)
elif dim == 1:
return self.einsum_op_slice_dim1(q, k, v, slice_size)
# fallback for when there is no saved strategy, or saved strategy does not slice
mem_free_total = self.cached_mem_free_total or get_mem_free_total(q.device)
# Divide factor of safety as there's copying and fragmentation
return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
def get_invokeai_attention_mem_efficient(self, q, k, v):
if q.device.type == 'cuda':
#print("in get_attention_mem_efficient with q shape", q.shape, ", k shape", k.shape, ", free memory is", get_mem_free_total(q.device))
return self.einsum_op_cuda(q, k, v)
if q.device.type == 'mps' or q.device.type == 'cpu':
if self.mem_total_gb >= 32:
return self.einsum_op_mps_v1(q, k, v)
return self.einsum_op_mps_v2(q, k, v)
# Smaller slices are faster due to L2/L3/SLC caches.
# Tested on i7 with 8MB L3 cache.
return self.einsum_op_tensor_mem(q, k, v, 32)
def remove_cross_attention_control(model):
remove_attention_function(model)
@@ -187,7 +322,7 @@ def setup_cross_attention_control(model, context: Context):
# mask=1 means use base prompt attention, mask=0 means use edited prompt attention
mask = torch.zeros(max_length)
indices_target = torch.arange(max_length, dtype=torch.long)
-indices = torch.zeros(max_length, dtype=torch.long)
indices = torch.arange(max_length, dtype=torch.long)
for name, a0, a1, b0, b1 in context.arguments.edit_opcodes:
if b0 < max_length:
if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0):
@@ -201,10 +336,23 @@ def setup_cross_attention_control(model, context: Context):
inject_attention_function(model, context)
-def get_attention_modules(model, which: CrossAttentionType):
def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
cross_attention_class: type = InvokeAICrossAttentionMixin
# cross_attention_class: type = InvokeAIDiffusersCrossAttention
which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2"
-return [(name,module) for name, module in model.named_modules() if
-type(module).__name__ == "CrossAttention" and which_attn in name]
attention_module_tuples = [(name,module) for name, module in model.named_modules() if
isinstance(module, cross_attention_class) and which_attn in name]
cross_attention_modules_in_model_count = len(attention_module_tuples)
expected_count = 16
if cross_attention_modules_in_model_count != expected_count:
# non-fatal error but .swap() won't work.
print(f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model " +
f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching failed " +
f"or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, " +
f"and/or update the {expected_count} above to an appropriate number, and/or find and inform someone who knows " +
f"what it means. This error is non-fatal, but it is likely that .swap() and attention map display will not " +
f"work properly until it is fixed.")
return attention_module_tuples
def inject_attention_function(unet, context: Context):
@@ -244,19 +392,52 @@ def inject_attention_function(unet, context: Context):
return attention_slice
-for name, module in unet.named_modules():
-module_name = type(module).__name__
-if module_name == "CrossAttention":
-module.identifier = name
-module.set_attention_slice_wrangler(attention_slice_wrangler)
-module.set_slicing_strategy_getter(lambda module, module_identifier=name: \
-context.get_slicing_strategy(module_identifier))
cross_attention_modules = get_cross_attention_modules(unet, CrossAttentionType.TOKENS) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
for identifier, module in cross_attention_modules:
module.identifier = identifier
try:
module.set_attention_slice_wrangler(attention_slice_wrangler)
module.set_slicing_strategy_getter(
lambda module: context.get_slicing_strategy(identifier)
)
except AttributeError as e:
if is_attribute_error_about(e, 'set_attention_slice_wrangler'):
print(f"TODO: implement set_attention_slice_wrangler for {type(module)}") # TODO
else:
raise
def remove_attention_function(unet):
-# clear wrangler callback
-for name, module in unet.named_modules():
-module_name = type(module).__name__
-if module_name == "CrossAttention":
-module.set_attention_slice_wrangler(None)
-module.set_slicing_strategy_getter(None)
cross_attention_modules = get_cross_attention_modules(unet, CrossAttentionType.TOKENS) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
for identifier, module in cross_attention_modules:
try:
# clear wrangler callback
module.set_attention_slice_wrangler(None)
module.set_slicing_strategy_getter(None)
except AttributeError as e:
if is_attribute_error_about(e, 'set_attention_slice_wrangler'):
print(f"TODO: implement set_attention_slice_wrangler for {type(module)}")
else:
raise
def is_attribute_error_about(error: AttributeError, attribute: str):
if hasattr(error, 'name'): # Python 3.10
return error.name == attribute
else: # Python 3.9
return attribute in str(error)
def get_mem_free_total(device):
#only on cuda
if not torch.cuda.is_available():
return None
stats = torch.cuda.memory_stats(device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(device)
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
return mem_free_total
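For reference, a minimal sketch of an attention_slice_wrangler as described by the InvokeAICrossAttentionMixin docstring above; this one only records slice shapes and returns the suggested slice unchanged (the logging list is illustrative):

    import torch
    from torch import nn

    observed = []   # illustrative: (module identifier, slice shape, dim, offset, slice_size)

    def shape_logging_wrangler(module: nn.Module, suggested_attention_slice: torch.Tensor,
                               dim, offset, slice_size) -> torch.Tensor:
        # dim is None/-1 when the map was not sliced, 0 or 1 when it was
        observed.append((getattr(module, "identifier", None),
                         tuple(suggested_attention_slice.shape), dim, offset, slice_size))
        return suggested_attention_slice   # leave the attention untouched

    # wiring sketch: module.set_attention_slice_wrangler(shape_logging_wrangler)
    # and module.set_attention_slice_wrangler(None) to restore the default calculation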


@@ -0,0 +1,95 @@
import math
import PIL
import torch
from torchvision.transforms.functional import resize as tv_resize, InterpolationMode
from ldm.models.diffusion.cross_attention_control import get_cross_attention_modules, CrossAttentionType
class AttentionMapSaver():
def __init__(self, token_ids: range, latents_shape: torch.Size):
self.token_ids = token_ids
self.latents_shape = latents_shape
#self.collated_maps = #torch.zeros([len(token_ids), latents_shape[0], latents_shape[1]])
self.collated_maps = {}
def clear_maps(self):
self.collated_maps = {}
def add_attention_maps(self, maps: torch.Tensor, key: str):
"""
Accumulate the given attention maps and store by summing with existing maps at the passed-in key (if any).
:param maps: Attention maps to store. Expected shape [A, (H*W), N] where A is attention heads count, H and W are the map size (fixed per-key) and N is the number of tokens (typically 77).
:param key: Storage key. If a map already exists for this key it will be summed with the incoming data. In this case the maps sizes (H and W) should match.
:return: None
"""
key_and_size = f'{key}_{maps.shape[1]}'
# extract desired tokens
maps = maps[:, :, self.token_ids]
# merge attention heads to a single map per token
maps = torch.sum(maps, 0)
# store
if key_and_size not in self.collated_maps:
self.collated_maps[key_and_size] = torch.zeros_like(maps, device='cpu')
self.collated_maps[key_and_size] += maps.cpu()
def write_maps_to_disk(self, path: str):
pil_image = self.get_stacked_maps_image()
pil_image.save(path, 'PNG')
def get_stacked_maps_image(self) -> PIL.Image:
"""
Scale all collected attention maps to the same size, blend them together and return as an image.
:return: An image containing a vertical stack of blended attention maps, one for each requested token.
"""
num_tokens = len(self.token_ids)
if num_tokens == 0:
return None
latents_height = self.latents_shape[0]
latents_width = self.latents_shape[1]
merged = None
for key, maps in self.collated_maps.items():
# maps has shape [(H*W), N] for N tokens
# but we want [N, H, W]
this_scale_factor = math.sqrt(maps.shape[0] / (latents_width * latents_height))
this_maps_height = int(float(latents_height) * this_scale_factor)
this_maps_width = int(float(latents_width) * this_scale_factor)
# and we need to do some dimension juggling
maps = torch.reshape(torch.swapdims(maps, 0, 1), [num_tokens, this_maps_height, this_maps_width])
# scale to output size if necessary
if this_scale_factor != 1:
maps = tv_resize(maps, [latents_height, latents_width], InterpolationMode.BICUBIC)
# normalize
maps_min = torch.min(maps)
maps_range = torch.max(maps) - maps_min
#print(f"map {key} size {[this_maps_width, this_maps_height]} range {[maps_min, maps_min + maps_range]}")
maps_normalized = (maps - maps_min) / maps_range
# expand to (-0.1, 1.1) and clamp
maps_normalized_expanded = maps_normalized * 1.1 - 0.05
maps_normalized_expanded_clamped = torch.clamp(maps_normalized_expanded, 0, 1)
# merge together, producing a vertical stack
maps_stacked = torch.reshape(maps_normalized_expanded_clamped, [num_tokens * latents_height, latents_width])
if merged is None:
merged = maps_stacked
else:
# screen blend
merged = 1 - (1 - maps_stacked)*(1 - merged)
if merged is None:
return None
merged_bytes = merged.mul(0xff).byte()
return PIL.Image.fromarray(merged_bytes.numpy(), mode='L')
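A small usage sketch of AttentionMapSaver with synthetic tensors (the head count, resolutions and /tmp path are illustrative; in the real pipeline add_attention_maps is fed by the cross-attention callback wired up in shared_invokeai_diffusion.py further below):

    import torch
    from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver

    # pretend the latents are 64x64 and we want maps for token ids 1..5 (BOS/EOS excluded)
    saver = AttentionMapSaver(token_ids=range(1, 6), latents_shape=torch.Size([64, 64]))

    heads, token_count = 8, 77
    saver.add_attention_maps(torch.rand(heads, 64 * 64, token_count), key='down')  # fine resolution
    saver.add_attention_maps(torch.rand(heads, 32 * 32, token_count), key='mid')   # coarser resolution

    stacked = saver.get_stacked_maps_image()   # greyscale PIL image, one map per token stacked vertically
    stacked.save('/tmp/attention_maps.png')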


@@ -4,6 +4,7 @@ import k_diffusion as K
import torch
from torch import nn
from .cross_attention_map_saving import AttentionMapSaver
from .sampler import Sampler
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent
@@ -36,6 +37,7 @@ class CFGDenoiser(nn.Module):
self.invokeai_diffuser = InvokeAIDiffuserComponent(model,
model_forward_callback=lambda x, sigma, cond: self.inner_model(x, sigma, cond=cond))
def prepare_to_sample(self, t_enc, **kwargs):
extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
@@ -158,6 +160,7 @@ class KSampler(Sampler):
callback=None,
normals_sequence=None,
img_callback=None,
attention_maps_callback=None,
quantize_x0=False,
eta=0.0,
mask=None,
@@ -171,7 +174,7 @@ class KSampler(Sampler):
log_every_t=100,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
-extra_conditioning_info=None,
extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo=None,
threshold = 0,
perlin = 0,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
@@ -204,6 +207,12 @@ class KSampler(Sampler):
model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10))
model_wrap_cfg.prepare_to_sample(S, extra_conditioning_info=extra_conditioning_info)
attention_map_token_ids = range(1, extra_conditioning_info.tokens_count_including_eos_bos - 1)
attention_maps_saver = None if attention_maps_callback is None else AttentionMapSaver(token_ids = attention_map_token_ids, latents_shape=x.shape[-2:])
if attention_maps_callback is not None:
model_wrap_cfg.invokeai_diffuser.setup_attention_map_saving(attention_maps_saver)
extra_args = {
'cond': conditioning,
'uncond': unconditional_conditioning,
@@ -217,6 +226,8 @@ class KSampler(Sampler):
),
None,
)
if attention_maps_callback is not None:
attention_maps_callback(attention_maps_saver)
return sampling_result
# this code will support inpainting if and when ksampler API modified or
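A short sketch of how a caller can hook the attention_maps_callback parameter threaded through the sampler above (Generator.generate() wires an equivalent lambda internally; the save path is illustrative):

    def save_maps(saver):
        # called once after sampling with the populated AttentionMapSaver
        if saver is not None:
            saver.write_maps_to_disk('/tmp/attention_maps.png')

    # passed through to the sampler's sample(...) call alongside extra_conditioning_info,
    # whose tokens_count_including_eos_bos determines which token ids get saved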


@@ -5,8 +5,8 @@ from typing import Callable, Optional, Union
import torch
from ldm.models.diffusion.cross_attention_control import Arguments, \
-remove_cross_attention_control, setup_cross_attention_control, Context
remove_cross_attention_control, setup_cross_attention_control, Context, get_cross_attention_modules, CrossAttentionType
-from ldm.modules.attention import get_mem_free_total
from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
class InvokeAIDiffuserComponent:
@@ -21,7 +21,8 @@ class InvokeAIDiffuserComponent:
class ExtraConditioningInfo:
-def __init__(self, cross_attention_control_args: Optional[Arguments]):
def __init__(self, tokens_count_including_eos_bos:int, cross_attention_control_args: Optional[Arguments]):
self.tokens_count_including_eos_bos = tokens_count_including_eos_bos
self.cross_attention_control_args = cross_attention_control_args
@property
@@ -52,7 +53,25 @@ class InvokeAIDiffuserComponent:
self.cross_attention_control_context = None
remove_cross_attention_control(self.model)
def setup_attention_map_saving(self, saver: AttentionMapSaver):
def callback(slice, dim, offset, slice_size, key):
if dim is not None:
# sliced tokens attention map saving is not implemented
return
saver.add_attention_maps(slice, key)
tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
for identifier, module in tokens_cross_attention_modules:
key = ('down' if identifier.startswith('down') else
'up' if identifier.startswith('up') else
'mid')
module.set_attention_slice_calculated_callback(
lambda slice, dim, offset, slice_size, key=key: callback(slice, dim, offset, slice_size, key))
def remove_attention_map_saving(self):
tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
for _, module in tokens_cross_attention_modules:
module.set_attention_slice_calculated_callback(None)
def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor,
unconditioning: Union[torch.Tensor,dict],


@@ -7,10 +7,9 @@ import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from ldm.models.diffusion.cross_attention_control import InvokeAICrossAttentionMixin
from ldm.modules.diffusionmodules.util import checkpoint
-import psutil
def exists(val):
return val is not None
@@ -164,9 +163,10 @@ def get_mem_free_total(device):
return mem_free_total
-class CrossAttention(nn.Module):
class CrossAttention(nn.Module, InvokeAICrossAttentionMixin):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
super().__init__()
InvokeAICrossAttentionMixin.__init__(self)
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
@@ -182,118 +182,6 @@ class CrossAttention(nn.Module):
nn.Dropout(dropout)
)
self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
self.cached_mem_free_total = None
self.attention_slice_wrangler = None
self.slicing_strategy_getter = None
def set_attention_slice_wrangler(self, wrangler: Optional[Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]]):
'''
Set custom attention calculator to be called when attention is calculated
:param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size),
which returns either the suggested_attention_slice or an adjusted equivalent.
`module` is the current CrossAttention module for which the callback is being invoked.
`suggested_attention_slice` is the default-calculated attention slice
`dim` is -1 if the attenion map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length.
Pass None to use the default attention calculation.
:return:
'''
self.attention_slice_wrangler = wrangler
def set_slicing_strategy_getter(self, getter: Optional[Callable[[nn.Module], tuple[int,int]]]):
self.slicing_strategy_getter = getter
def cache_free_memory_count(self, device):
self.cached_mem_free_total = get_mem_free_total(device)
print("free cuda memory: ", self.cached_mem_free_total)
def clear_cached_free_memory_count(self):
self.cached_mem_free_total = None
def einsum_lowest_level(self, q, k, v, dim, offset, slice_size):
# calculate attention scores
attention_scores = einsum('b i d, b j d -> b i j', q, k)
# calculate attention slice by taking the best scores for each latent pixel
default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
attention_slice_wrangler = self.attention_slice_wrangler
if attention_slice_wrangler is not None:
attention_slice = attention_slice_wrangler(self, default_attention_slice, dim, offset, slice_size)
else:
attention_slice = default_attention_slice
return einsum('b i j, b j d -> b i d', attention_slice, v)
def einsum_op_slice_dim0(self, q, k, v, slice_size):
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
for i in range(0, q.shape[0], slice_size):
end = i + slice_size
r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size)
return r
def einsum_op_slice_dim1(self, q, k, v, slice_size):
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size)
return r
def einsum_op_mps_v1(self, q, k, v):
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
return self.einsum_lowest_level(q, k, v, None, None, None)
else:
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
return self.einsum_op_slice_dim1(q, k, v, slice_size)
def einsum_op_mps_v2(self, q, k, v):
if self.mem_total_gb > 8 and q.shape[1] <= 4096:
return self.einsum_lowest_level(q, k, v, None, None, None)
else:
return self.einsum_op_slice_dim0(q, k, v, 1)
def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb):
size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
if size_mb <= max_tensor_mb:
return self.einsum_lowest_level(q, k, v, None, None, None)
div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
if div <= q.shape[0]:
return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div)
return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1))
def einsum_op_cuda(self, q, k, v):
# check if we already have a slicing strategy (this should only happen during cross-attention controlled generation)
slicing_strategy_getter = self.slicing_strategy_getter
if slicing_strategy_getter is not None:
(dim, slice_size) = slicing_strategy_getter(self)
if dim is not None:
# print("using saved slicing strategy with dim", dim, "slice size", slice_size)
if dim == 0:
return self.einsum_op_slice_dim0(q, k, v, slice_size)
elif dim == 1:
return self.einsum_op_slice_dim1(q, k, v, slice_size)
# fallback for when there is no saved strategy, or saved strategy does not slice
mem_free_total = self.cached_mem_free_total or get_mem_free_total(q.device)
# Divide factor of safety as there's copying and fragmentation
return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
def get_attention_mem_efficient(self, q, k, v):
if q.device.type == 'cuda':
#print("in get_attention_mem_efficient with q shape", q.shape, ", k shape", k.shape, ", free memory is", get_mem_free_total(q.device))
return self.einsum_op_cuda(q, k, v)
if q.device.type == 'mps':
if self.mem_total_gb >= 32:
return self.einsum_op_mps_v1(q, k, v)
return self.einsum_op_mps_v2(q, k, v)
# Smaller slices are faster due to L2/L3/SLC caches.
# Tested on i7 with 8MB L3 cache.
return self.einsum_op_tensor_mem(q, k, v, 32)
def forward(self, x, context=None, mask=None):
h = self.heads
@@ -305,7 +193,11 @@ class CrossAttention(nn.Module):
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
-r = self.get_attention_mem_efficient(q, k, v)
# don't apply scale twice
cached_scale = self.scale
self.scale = 1
r = self.get_invokeai_attention_mem_efficient(q, k, v)
self.scale = cached_scale
hidden_states = rearrange(r, '(b h) n d -> b n (h d)', h=h)
return self.to_out(hidden_states)
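A note on the scale caching above: InvokeAICrossAttentionMixin.einsum_lowest_level already multiplies the scores by self.scale (via baddbmm(..., alpha=self.scale)), while the surrounding ldm attention path is assumed to have applied the scale to the query already, so forward temporarily sets self.scale = 1 to avoid applying it twice. A toy equivalence check (shapes are illustrative):

    import torch

    q, k = torch.randn(2, 16, 8), torch.randn(2, 16, 8)
    scale = 8 ** -0.5

    # scale applied inside the matmul (what the mixin does via alpha=self.scale)
    inside = torch.baddbmm(torch.empty(2, 16, 16), q, k.transpose(-1, -2), beta=0, alpha=scale)

    # scale already folded into q, with the mixin's alpha forced to 1 (what forward arranges)
    outside = torch.baddbmm(torch.empty(2, 16, 16), q * scale, k.transpose(-1, -2), beta=0, alpha=1)

    assert torch.allclose(inside, outside, atol=1e-5)   # same scores, scale applied exactly once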