Merge remote-tracking branch 'origin/main' into dev/diffusers

# Conflicts:
#	backend/invoke_ai_web_server.py
#	ldm/generate.py
#	ldm/invoke/CLI.py
#	ldm/invoke/generator/base.py
#	ldm/invoke/generator/txt2img.py
#	ldm/models/diffusion/cross_attention_control.py
#	ldm/modules/attention.py
Kevin Turner, 2022-12-10 08:43:37 -08:00
commit 63532226a5
13 changed files with 515 additions and 181 deletions

backend/invoke_ai_web_server.py

@@ -21,6 +21,7 @@ from backend.modules.get_canvas_generation_mode import (
get_canvas_generation_mode,
)
from backend.modules.parameters import parameters_to_command
+from ldm.generate import Generate
from ldm.invoke.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash
from ldm.invoke.generator.diffusers_pipeline import PipelineIntermediateState
from ldm.invoke.generator.inpaint import infill_methods
@@ -38,7 +39,7 @@ if not os.path.isabs(args.outdir):
class InvokeAIWebServer:
-def __init__(self, generate, gfpgan, codeformer, esrgan) -> None:
+def __init__(self, generate: Generate, gfpgan, codeformer, esrgan) -> None:
self.host = args.host
self.port = args.port
@@ -906,16 +907,13 @@ class InvokeAIWebServer:
},
)
if generation_parameters["progress_latents"]:
image = self.generate.sample_to_lowres_estimated_image(sample)
(width, height) = image.size
width *= 8
height *= 8
-buffered = io.BytesIO()
-image.save(buffered, format="PNG")
-img_base64 = "data:image/png;base64," + base64.b64encode(
-buffered.getvalue()
-).decode("UTF-8")
+img_base64 = image_to_dataURL(image)
self.socketio.emit(
"intermediateResult",
{
@@ -933,7 +931,7 @@ class InvokeAIWebServer:
self.socketio.emit("progressUpdate", progress.to_formatted_dict())
eventlet.sleep(0)
-def image_done(image, seed, first_seed):
+def image_done(image, seed, first_seed, attention_maps_image=None):
if self.canceled.is_set():
raise CanceledException
@@ -1095,6 +1093,12 @@ class InvokeAIWebServer:
self.socketio.emit("progressUpdate", progress.to_formatted_dict())
eventlet.sleep(0)
+parsed_prompt, _ = get_prompt_structure(generation_parameters["prompt"])
+tokens = None if type(parsed_prompt) is Blend else \
+get_tokens_for_prompt(self.generate.model, parsed_prompt)
+attention_maps_image_base64_url = None if attention_maps_image is None \
+else image_to_dataURL(attention_maps_image)
self.socketio.emit(
"generationResult",
{
@@ -1107,6 +1111,8 @@ class InvokeAIWebServer:
"height": height,
"boundingBox": original_bounding_box,
"generationMode": generation_parameters["generation_mode"],
+"attentionMaps": attention_maps_image_base64_url,
+"tokens": tokens,
},
)
eventlet.sleep(0)
@@ -1118,7 +1124,7 @@ class InvokeAIWebServer:
self.generate.prompt2image(
**generation_parameters,
step_callback=image_progress,
-image_callback=image_done,
+image_callback=image_done
)
except KeyboardInterrupt:
@@ -1565,6 +1571,19 @@ def dataURL_to_image(dataURL: str) -> ImageType:
)
return image
+"""
+Converts an image into a base64 image dataURL.
+"""
+def image_to_dataURL(image: ImageType) -> str:
+buffered = io.BytesIO()
+image.save(buffered, format="PNG")
+image_base64 = "data:image/png;base64," + base64.b64encode(
+buffered.getvalue()
+).decode("UTF-8")
+return image_base64
"""
Converts a base64 image dataURL into bytes.
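For reference, a minimal standalone sketch (assuming only Pillow is installed) of the encoding the new image_to_dataURL() helper performs, and of decoding the resulting data URL back into an image, which is the reverse of the dataURL_to_image() conversion above:

    import base64
    import io

    from PIL import Image

    def image_to_dataURL(image: Image.Image) -> str:
        # Same encoding as the helper added above: PNG bytes, base64, data URL prefix.
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        return "data:image/png;base64," + base64.b64encode(buffered.getvalue()).decode("UTF-8")

    preview = Image.new("RGB", (64, 64), "black")
    url = image_to_dataURL(preview)
    assert url.startswith("data:image/png;base64,")

    # Strip the prefix and decode back to a PIL image.
    decoded = Image.open(io.BytesIO(base64.b64decode(url.split(",", 1)[1])))
    assert decoded.size == (64, 64)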

ldm/generate.py

@@ -16,7 +16,6 @@ import numpy as np
import skimage
import torch
import transformers
-from PIL import Image, ImageOps
from diffusers import HeunDiscreteScheduler
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
@@ -27,6 +26,7 @@ from diffusers.schedulers.scheduling_ipndm import IPNDMScheduler
from diffusers.schedulers.scheduling_lms_discrete import LMSDiscreteScheduler
from diffusers.schedulers.scheduling_pndm import PNDMScheduler
from omegaconf import OmegaConf
+from PIL import Image, ImageOps
from pytorch_lightning import seed_everything, logging
from ldm.invoke.args import metadata_from_png
@@ -461,7 +461,7 @@ class Generate:
try:
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(
prompt, model =self.model,
-skip_normalize=skip_normalize,
+skip_normalize_legacy_blend=skip_normalize,
log_tokens =self.log_tokenization
)
@@ -613,8 +613,8 @@ class Generate:
# todo: cross-attention control
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(
prompt, model =self.model,
-skip_normalize=opt.skip_normalize,
+skip_normalize_legacy_blend=opt.skip_normalize,
-log_tokens =opt.log_tokenization
+log_tokens =ldm.invoke.conditioning.log_tokenization
)
if tool in ('gfpgan','codeformer','upscale'):

ldm/invoke/CLI.py

@@ -8,6 +8,7 @@ import time
import traceback
import yaml
+from ldm.generate import Generate
from ldm.invoke.globals import Globals
from ldm.invoke.prompt_parser import PromptParser
from ldm.invoke.readline import get_completer, Completer
@@ -282,7 +283,7 @@ def main_loop(gen, opt):
prefix = file_writer.unique_prefix()
step_callback = make_step_callback(gen, opt, prefix) if opt.save_intermediates > 0 else None
-def image_writer(image, seed, upscaled=False, first_seed=None, use_prefix=None, prompt_in=None):
+def image_writer(image, seed, upscaled=False, first_seed=None, use_prefix=None, prompt_in=None, attention_maps_image=None):
# note the seed is the seed of the current image
# the first_seed is the original seed that noise is added to
# when the -v switch is used to generate variations
@@ -790,7 +791,7 @@ def get_next_command(infile=None) -> str: # command string
print(f'#{command}')
return command
-def invoke_ai_web_server_loop(gen, gfpgan, codeformer, esrgan):
+def invoke_ai_web_server_loop(gen: Generate, gfpgan, codeformer, esrgan):
print('\n* --web was specified, starting web server...')
from backend.invoke_ai_web_server import InvokeAIWebServer
# Change working directory to the stable-diffusion directory

ldm/invoke/conditioning.py

@@ -7,20 +7,46 @@ get_uc_and_c_and_ec() get the conditioned and unconditioned latent, an
'''
import re
-from difflib import SequenceMatcher
from typing import Union
import torch
from .prompt_parser import PromptParser, Blend, FlattenedPrompt, \
-CrossAttentionControlledFragment, CrossAttentionControlSubstitute, Fragment, log_tokenization
+CrossAttentionControlledFragment, CrossAttentionControlSubstitute, Fragment
from ..models.diffusion import cross_attention_control
from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder
-def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_normalize=False):
+def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False):
prompt, negative_prompt = get_prompt_structure(prompt_string,
skip_normalize_legacy_blend=skip_normalize_legacy_blend)
conditioning = _get_conditioning_for_prompt(prompt, negative_prompt, model, log_tokens)
return conditioning
def get_prompt_structure(prompt_string, skip_normalize_legacy_blend: bool = False) -> (
Union[FlattenedPrompt, Blend], FlattenedPrompt):
"""
parse the passed-in prompt string and return tuple (positive_prompt, negative_prompt)
"""
prompt, negative_prompt = _parse_prompt_string(prompt_string,
skip_normalize_legacy_blend=skip_normalize_legacy_blend)
return prompt, negative_prompt
def get_tokens_for_prompt(model, parsed_prompt: FlattenedPrompt) -> [str]:
text_fragments = [x.text if type(x) is Fragment else
(" ".join([f.text for f in x.original]) if type(x) is CrossAttentionControlSubstitute else
str(x))
for x in parsed_prompt.children]
text = " ".join(text_fragments)
tokens = model.cond_stage_model.tokenizer.tokenize(text)
return tokens
def _parse_prompt_string(prompt_string_uncleaned, skip_normalize_legacy_blend=False) -> Union[FlattenedPrompt, Blend]:
# Extract Unconditioned Words From Prompt
unconditioned_words = ''
unconditional_regex = r'\[(.*?)\]'
@@ -39,7 +65,7 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
pp = PromptParser()
parsed_prompt: Union[FlattenedPrompt, Blend] = None
-legacy_blend: Blend = pp.parse_legacy_blend(prompt_string_cleaned)
+legacy_blend: Blend = pp.parse_legacy_blend(prompt_string_cleaned, skip_normalize_legacy_blend)
if legacy_blend is not None:
parsed_prompt = legacy_blend
else:
@@ -47,129 +73,150 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
parsed_prompt = pp.parse_conjunction(prompt_string_cleaned).prompts[0]
parsed_negative_prompt: FlattenedPrompt = pp.parse_conjunction(unconditioned_words).prompts[0]
return parsed_prompt, parsed_negative_prompt
def _get_conditioning_for_prompt(parsed_prompt: Union[Blend, FlattenedPrompt], parsed_negative_prompt: FlattenedPrompt,
model, log_tokens=False) \
-> tuple[torch.Tensor, torch.Tensor, InvokeAIDiffuserComponent.ExtraConditioningInfo]:
"""
Process prompt structure and tokens, and return (conditioning, unconditioning, extra_conditioning_info)
"""
if log_tokens:
print(f">> Parsed prompt to {parsed_prompt}")
print(f">> Parsed negative prompt to {parsed_negative_prompt}")
conditioning = None
-cac_args:cross_attention_control.Arguments = None
+cac_args: cross_attention_control.Arguments = None
if type(parsed_prompt) is Blend:
-blend: Blend = parsed_prompt
+conditioning = _get_conditioning_for_blend(model, parsed_prompt, log_tokens)
-embeddings_to_blend = None
+elif type(parsed_prompt) is FlattenedPrompt:
-for i,flattened_prompt in enumerate(blend.prompts):
+if parsed_prompt.wants_cross_attention_control:
-this_embedding, _ = build_embeddings_and_tokens_for_flattened_prompt(model,
+conditioning, cac_args = _get_conditioning_for_cross_attention_control(model, parsed_prompt, log_tokens)
flattened_prompt,
log_tokens=log_tokens,
log_display_label=f"(blend part {i+1}, weight={blend.weights[i]})" )
embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat(
(embeddings_to_blend, this_embedding))
conditioning = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0),
blend.weights,
normalize=blend.normalize_weights)
else:
flattened_prompt: FlattenedPrompt = parsed_prompt
wants_cross_attention_control = type(flattened_prompt) is not Blend \
and any([issubclass(type(x), CrossAttentionControlledFragment) for x in flattened_prompt.children])
if wants_cross_attention_control:
original_prompt = FlattenedPrompt()
edited_prompt = FlattenedPrompt()
# for name, a0, a1, b0, b1 in edit_opcodes: only name == 'equal' is currently parsed
original_token_count = 0
edited_token_count = 0
edit_options = []
edit_opcodes = []
# beginning of sequence
edit_opcodes.append(('equal', original_token_count, original_token_count+1, edited_token_count, edited_token_count+1))
edit_options.append(None)
original_token_count += 1
edited_token_count += 1
for fragment in flattened_prompt.children:
if type(fragment) is CrossAttentionControlSubstitute:
original_prompt.append(fragment.original)
edited_prompt.append(fragment.edited)
to_replace_token_count = get_tokens_length(model, fragment.original)
replacement_token_count = get_tokens_length(model, fragment.edited)
edit_opcodes.append(('replace',
original_token_count, original_token_count + to_replace_token_count,
edited_token_count, edited_token_count + replacement_token_count
))
original_token_count += to_replace_token_count
edited_token_count += replacement_token_count
edit_options.append(fragment.options)
#elif type(fragment) is CrossAttentionControlAppend:
# edited_prompt.append(fragment.fragment)
else:
# regular fragment
original_prompt.append(fragment)
edited_prompt.append(fragment)
count = get_tokens_length(model, [fragment])
edit_opcodes.append(('equal', original_token_count, original_token_count+count, edited_token_count, edited_token_count+count))
edit_options.append(None)
original_token_count += count
edited_token_count += count
# end of sequence
edit_opcodes.append(('equal', original_token_count, original_token_count+1, edited_token_count, edited_token_count+1))
edit_options.append(None)
original_token_count += 1
edited_token_count += 1
original_embeddings, original_tokens = build_embeddings_and_tokens_for_flattened_prompt(model,
original_prompt,
log_tokens=log_tokens,
log_display_label="(.swap originals)")
# naïvely building a single edited_embeddings like this disregards the effects of changing the absolute location of
# subsequent tokens when there is >1 edit and earlier edits change the total token count.
# eg "a cat.swap(smiling dog, s_start=0.5) eating a hotdog.swap(pizza)" - when the 'pizza' edit is active but the
# 'cat' edit is not, the 'pizza' feature vector will nevertheless be affected by the introduction of the extra
# token 'smiling' in the inactive 'cat' edit.
# todo: build multiple edited_embeddings, one for each edit, and pass just the edited fragments through to the CrossAttentionControl functions
edited_embeddings, edited_tokens = build_embeddings_and_tokens_for_flattened_prompt(model,
edited_prompt,
log_tokens=log_tokens,
log_display_label="(.swap replacements)")
conditioning = original_embeddings
edited_conditioning = edited_embeddings
#print('>> got edit_opcodes', edit_opcodes, 'options', edit_options)
cac_args = cross_attention_control.Arguments(
edited_conditioning = edited_conditioning,
edit_opcodes = edit_opcodes,
edit_options = edit_options
)
else:
-conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model,
+conditioning, _ = _get_embeddings_and_tokens_for_prompt(model,
-flattened_prompt,
+parsed_prompt,
log_tokens=log_tokens,
log_display_label="(prompt)")
+else:
+raise ValueError(f"parsed_prompt is '{type(parsed_prompt)}' which is not a supported prompt type")
-unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model,
+unconditioning, _ = _get_embeddings_and_tokens_for_prompt(model,
parsed_negative_prompt,
log_tokens=log_tokens,
log_display_label="(unconditioning)")
if isinstance(conditioning, dict):
# hybrid conditioning is in play
-unconditioning, conditioning = flatten_hybrid_conditioning(unconditioning, conditioning)
+unconditioning, conditioning = _flatten_hybrid_conditioning(unconditioning, conditioning)
if cac_args is not None:
-print(">> Hybrid conditioning cannot currently be combined with cross attention control. Cross attention control will be ignored.")
+print(
+">> Hybrid conditioning cannot currently be combined with cross attention control. Cross attention control will be ignored.")
cac_args = None
+eos_token_index = 1
+if type(parsed_prompt) is not Blend:
+tokens = get_tokens_for_prompt(model, parsed_prompt)
+eos_token_index = len(tokens)+1
return (
unconditioning, conditioning, InvokeAIDiffuserComponent.ExtraConditioningInfo(
+tokens_count_including_eos_bos=eos_token_index + 1,
cross_attention_control_args=cac_args
)
)
-def build_token_edit_opcodes(original_tokens, edited_tokens):
+def _get_conditioning_for_cross_attention_control(model, prompt: FlattenedPrompt, log_tokens: bool = True):
-original_tokens = original_tokens.cpu().numpy()[0]
+original_prompt = FlattenedPrompt()
-edited_tokens = edited_tokens.cpu().numpy()[0]
+edited_prompt = FlattenedPrompt()
# for name, a0, a1, b0, b1 in edit_opcodes: only name == 'equal' is currently parsed
original_token_count = 0
edited_token_count = 0
edit_options = []
edit_opcodes = []
# beginning of sequence
edit_opcodes.append(
('equal', original_token_count, original_token_count + 1, edited_token_count, edited_token_count + 1))
edit_options.append(None)
original_token_count += 1
edited_token_count += 1
for fragment in prompt.children:
if type(fragment) is CrossAttentionControlSubstitute:
original_prompt.append(fragment.original)
edited_prompt.append(fragment.edited)
-return SequenceMatcher(None, original_tokens, edited_tokens).get_opcodes()
+to_replace_token_count = _get_tokens_length(model, fragment.original)
replacement_token_count = _get_tokens_length(model, fragment.edited)
edit_opcodes.append(('replace',
original_token_count, original_token_count + to_replace_token_count,
edited_token_count, edited_token_count + replacement_token_count
))
original_token_count += to_replace_token_count
edited_token_count += replacement_token_count
edit_options.append(fragment.options)
# elif type(fragment) is CrossAttentionControlAppend:
# edited_prompt.append(fragment.fragment)
else:
# regular fragment
original_prompt.append(fragment)
edited_prompt.append(fragment)
-def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt, log_tokens: bool=False, log_display_label: str=None):
+count = _get_tokens_length(model, [fragment])
edit_opcodes.append(('equal', original_token_count, original_token_count + count, edited_token_count,
edited_token_count + count))
edit_options.append(None)
original_token_count += count
edited_token_count += count
# end of sequence
edit_opcodes.append(
('equal', original_token_count, original_token_count + 1, edited_token_count, edited_token_count + 1))
edit_options.append(None)
original_token_count += 1
edited_token_count += 1
original_embeddings, original_tokens = _get_embeddings_and_tokens_for_prompt(model,
original_prompt,
log_tokens=log_tokens,
log_display_label="(.swap originals)")
# naïvely building a single edited_embeddings like this disregards the effects of changing the absolute location of
# subsequent tokens when there is >1 edit and earlier edits change the total token count.
# eg "a cat.swap(smiling dog, s_start=0.5) eating a hotdog.swap(pizza)" - when the 'pizza' edit is active but the
# 'cat' edit is not, the 'pizza' feature vector will nevertheless be affected by the introduction of the extra
# token 'smiling' in the inactive 'cat' edit.
# todo: build multiple edited_embeddings, one for each edit, and pass just the edited fragments through to the CrossAttentionControl functions
edited_embeddings, edited_tokens = _get_embeddings_and_tokens_for_prompt(model,
edited_prompt,
log_tokens=log_tokens,
log_display_label="(.swap replacements)")
conditioning = original_embeddings
edited_conditioning = edited_embeddings
# print('>> got edit_opcodes', edit_opcodes, 'options', edit_options)
cac_args = cross_attention_control.Arguments(
edited_conditioning=edited_conditioning,
edit_opcodes=edit_opcodes,
edit_options=edit_options
)
return conditioning, cac_args
def _get_conditioning_for_blend(model, blend: Blend, log_tokens: bool = False):
embeddings_to_blend = None
for i, flattened_prompt in enumerate(blend.prompts):
this_embedding, _ = _get_embeddings_and_tokens_for_prompt(model,
flattened_prompt,
log_tokens=log_tokens,
log_display_label=f"(blend part {i + 1}, weight={blend.weights[i]})")
embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat(
(embeddings_to_blend, this_embedding))
conditioning = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0),
blend.weights,
normalize=blend.normalize_weights)
return conditioning
def _get_embeddings_and_tokens_for_prompt(model, flattened_prompt: FlattenedPrompt, log_tokens: bool = False,
log_display_label: str = None):
if type(flattened_prompt) is not FlattenedPrompt:
raise Exception(f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead")
fragments = [x.text for x in flattened_prompt.children]
@@ -181,12 +228,14 @@ def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: Fl
return embeddings, tokens
-def get_tokens_length(model, fragments: list[Fragment]):
+def _get_tokens_length(model, fragments: list[Fragment]):
fragment_texts = [x.text for x in fragments]
tokens = model.cond_stage_model.get_tokens(fragment_texts, include_start_and_end_markers=False)
return sum([len(x) for x in tokens])
-def flatten_hybrid_conditioning(uncond, cond):
+def _flatten_hybrid_conditioning(uncond, cond):
'''
This handles the choice between a conditional conditioning
that is a tensor (used by cross attention) vs one that has additional
@@ -205,4 +254,29 @@ def flatten_hybrid_conditioning(uncond, cond):
cond_flattened[k] = torch.cat([uncond[k], cond[k]])
return uncond, cond_flattened
def log_tokenization(text, model, display_label=None):
""" shows how the prompt is tokenized
# usually tokens have '</w>' to indicate end-of-word,
# but for readability it has been replaced with ' '
"""
tokens = model.cond_stage_model.tokenizer.tokenize(text)
tokenized = ""
discarded = ""
usedTokens = 0
totalTokens = len(tokens)
for i in range(0, totalTokens):
token = tokens[i].replace('</w>', ' ')
# alternate color
s = (usedTokens % 6) + 1
if i < model.cond_stage_model.max_length:
tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
usedTokens += 1
else: # over max token length
discarded = discarded + f"\x1b[0;3{s};40m{token}"
print(f"\n>> Tokens {display_label or ''} ({usedTokens}):\n{tokenized}\x1b[0m")
if discarded != "":
print(
f">> Tokens Discarded ({totalTokens - usedTokens}):\n{discarded}\x1b[0m"
)
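To make the opcode layout concrete, here is a small self-contained illustration of the bookkeeping that _get_conditioning_for_cross_attention_control() performs for a prompt with one .swap() substitution; the token counts are made-up stand-ins for what _get_tokens_length() would return, not real tokenizer output:

    def build_edit_opcodes(fragments):
        # fragments: ('equal', token_count) for unchanged text,
        #            ('replace', original_count, edited_count) for a .swap() substitute.
        opcodes = []
        orig = edited = 0
        # reserve one slot at the start for the BOS token, as in the function above
        opcodes.append(('equal', orig, orig + 1, edited, edited + 1))
        orig, edited = orig + 1, edited + 1
        for fragment in fragments:
            if fragment[0] == 'replace':
                _, n_orig, n_edited = fragment
                opcodes.append(('replace', orig, orig + n_orig, edited, edited + n_edited))
                orig, edited = orig + n_orig, edited + n_edited
            else:
                _, n = fragment
                opcodes.append(('equal', orig, orig + n, edited, edited + n))
                orig, edited = orig + n, edited + n
        # and one slot at the end for the EOS token
        opcodes.append(('equal', orig, orig + 1, edited, edited + 1))
        return opcodes

    # "a cat.swap(smiling dog) eating a hotdog": 2 tokens before the swap,
    # 1 original token replaced by 2 edited tokens, 3 tokens after it.
    print(build_edit_opcodes([('equal', 2), ('replace', 1, 2), ('equal', 3)]))
    # [('equal', 0, 1, 0, 1), ('equal', 1, 3, 1, 3), ('replace', 3, 4, 3, 5),
    #  ('equal', 4, 7, 5, 8), ('equal', 7, 8, 8, 9)]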

ldm/invoke/generator/base.py

@@ -19,6 +19,7 @@ from pytorch_lightning import seed_everything
from tqdm import trange
from ldm.invoke.devices import choose_autocast
+from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
from ldm.models.diffusion.ddpm import DiffusionWrapper
from ldm.util import rand_perlin_2d
@@ -62,9 +63,12 @@ class Generator:
def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None,
image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
safety_checker:dict=None,
+attention_maps_callback = None,
**kwargs):
scope = choose_autocast(self.precision)
self.safety_checker = safety_checker
+attention_maps_images = []
+attention_maps_callback = lambda saver: attention_maps_images.append(saver.get_stacked_maps_image())
make_image = self.get_make_image(
prompt,
sampler = sampler,
@@ -74,6 +78,7 @@ class Generator:
step_callback = step_callback,
threshold = threshold,
perlin = perlin,
+attention_maps_callback = attention_maps_callback,
**kwargs
)
results = []
@@ -109,7 +114,7 @@ class Generator:
results.append([image, seed])
if image_callback is not None:
-image_callback(image, seed, first_seed=first_seed)
+image_callback(image, seed, first_seed=first_seed, attention_maps_image=attention_maps_images[-1])
seed = self.new_seed()

ldm/invoke/generator/txt2img.py

@@ -15,6 +15,7 @@ class Txt2Img(Generator):
@torch.no_grad()
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
conditioning,width,height,step_callback=None,threshold=0.0,perlin=0.0,
+attention_maps_callback=None,
**kwargs):
"""
Returns a function returning an image derived from the prompt and the initial image
@@ -39,6 +40,7 @@ class Txt2Img(Generator):
extra_conditioning_info=extra_conditioning_info,
# TODO: eta = ddim_eta,
# TODO: threshold = threshold,
+attention_maps_callback = attention_maps_callback,
)
return pipeline.numpy_to_pil(pipeline_output.images)[0]

ldm/invoke/prompt_parser.py

@@ -3,7 +3,7 @@ from typing import Union, Optional
import re
import pyparsing as pp
'''
This module parses prompt strings and produces tree-like structures that can be used generate and control the conditioning tensors.
weighted subprompts.
Useful class exports:
@@ -69,6 +69,12 @@ class FlattenedPrompt():
return len(self.children) == 0 or \
(len(self.children) == 1 and len(self.children[0].text) == 0)
+@property
+def wants_cross_attention_control(self):
+return any(
+[issubclass(type(x), CrossAttentionControlledFragment) for x in self.children]
+)
def __repr__(self):
return f"FlattenedPrompt:{self.children}"
def __eq__(self, other):
@@ -240,6 +246,12 @@ class Blend():
self.weights = weights
self.normalize_weights = normalize_weights
+@property
+def wants_cross_attention_control(self):
+# blends cannot cross-attention control
+return False
def __repr__(self):
return f"Blend:{self.prompts} | weights {' ' if self.normalize_weights else '(non-normalized) '}{self.weights}"
def __eq__(self, other):
@@ -277,8 +289,8 @@ class PromptParser():
return self.flatten(root[0])
-def parse_legacy_blend(self, text: str) -> Optional[Blend]:
+def parse_legacy_blend(self, text: str, skip_normalize: bool) -> Optional[Blend]:
-weighted_subprompts = split_weighted_subprompts(text, skip_normalize=False)
+weighted_subprompts = split_weighted_subprompts(text, skip_normalize=skip_normalize)
if len(weighted_subprompts) <= 1:
return None
strings = [x[0] for x in weighted_subprompts]
@@ -287,7 +299,7 @@ class PromptParser():
parsed_conjunctions = [self.parse_conjunction(x) for x in strings]
flattened_prompts = [x.prompts[0] for x in parsed_conjunctions]
-return Blend(prompts=flattened_prompts, weights=weights, normalize_weights=True)
+return Blend(prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize)
def flatten(self, root: Conjunction, verbose = False) -> Conjunction:
@@ -641,27 +653,3 @@ def split_weighted_subprompts(text, skip_normalize=False)->list:
return [(x[0], equal_weight) for x in parsed_prompts]
return [(x[0], x[1] / weight_sum) for x in parsed_prompts]
# shows how the prompt is tokenized
# usually tokens have '</w>' to indicate end-of-word,
# but for readability it has been replaced with ' '
def log_tokenization(text, model, display_label=None):
tokens = model.cond_stage_model.tokenizer.tokenize(text)
tokenized = ""
discarded = ""
usedTokens = 0
totalTokens = len(tokens)
for i in range(0, totalTokens):
token = tokens[i].replace('</w>', ' ')
# alternate color
s = (usedTokens % 6) + 1
if i < model.cond_stage_model.max_length:
tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
usedTokens += 1
else: # over max token length
discarded = discarded + f"\x1b[0;3{s};40m{token}"
print(f"\n>> Tokens {display_label or ''} ({usedTokens}):\n{tokenized}\x1b[0m")
if discarded != "":
print(
f">> Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m"
)
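As a quick illustration of what the new skip_normalize plumbing changes (with made-up subprompt weights rather than real parser output): parse_legacy_blend() now forwards the flag to split_weighted_subprompts() and sets normalize_weights accordingly, so raw weights survive when normalization is skipped.

    def normalize_weights(weighted_subprompts, skip_normalize=False):
        # Mirrors the tail of split_weighted_subprompts(): either keep the raw
        # weights or rescale them so that they sum to 1.
        if skip_normalize:
            return weighted_subprompts
        total = sum(weight for _, weight in weighted_subprompts)
        return [(text, weight / total) for text, weight in weighted_subprompts]

    pairs = [("mountain", 1.0), ("lake", 3.0)]
    print(normalize_weights(pairs))                       # [('mountain', 0.25), ('lake', 0.75)]
    print(normalize_weights(pairs, skip_normalize=True))  # weights stay 1.0 and 3.0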

ldm/invoke/readline.py

@@ -53,7 +53,6 @@ COMMANDS = (
'--codeformer_fidelity','-cf',
'--upscale','-U',
'-save_orig','--save_original',
-'--skip_normalize','-x',
'--log_tokenization','-t',
'--hires_fix',
'--inpaint_replace','-r',
@@ -117,19 +116,19 @@ class Completer(object):
# extensions defined, so go directly into path completion mode
if self.extensions is not None:
self.matches = self._path_completions(text, state, self.extensions)
# looking for an image file
elif re.search(path_regexp,buffer):
do_shortcut = re.search('^'+'|'.join(IMG_FILE_COMMANDS),buffer)
self.matches = self._path_completions(text, state, IMG_EXTENSIONS,shortcut_ok=do_shortcut)
# looking for a seed
elif re.search('(-S\s*|--seed[=\s])\d*$',buffer):
self.matches= self._seed_completions(text,state)
elif re.search('<[\w-]*$',buffer):
self.matches= self._concept_completions(text,state)
# looking for a model
elif re.match('^'+'|'.join(MODEL_COMMANDS),buffer):
self.matches= self._model_completions(text, state)
@@ -227,7 +226,7 @@ class Completer(object):
if h_len < 1:
print('<empty history>')
return
for i in range(0,h_len):
line = self.get_history_item(i+1)
if match and match not in line:
@@ -367,7 +366,7 @@ class DummyCompleter(Completer):
def __init__(self,options):
super().__init__(options)
self.history = list()
def add_history(self,line):
self.history.append(line)

ldm/models/diffusion/cross_attention_control.py

@@ -7,11 +7,9 @@ import torch
import diffusers
from torch import nn
# adapted from bloc97's CrossAttentionControl colab
# https://github.com/bloc97/CrossAttentionControl
class Arguments:
def __init__(self, edited_conditioning: torch.Tensor, edit_opcodes: list[tuple], edit_options: dict):
"""
@@ -68,11 +66,11 @@ class Context:
self.clear_requests(cleanup=True)
def register_cross_attention_modules(self, model):
-for name,module in get_attention_modules(model, CrossAttentionType.SELF):
+for name,module in get_cross_attention_modules(model, CrossAttentionType.SELF):
if name in self.self_cross_attention_module_identifiers:
assert False, f"name {name} cannot appear more than once"
self.self_cross_attention_module_identifiers.append(name)
-for name,module in get_attention_modules(model, CrossAttentionType.TOKENS):
+for name,module in get_cross_attention_modules(model, CrossAttentionType.TOKENS):
if name in self.tokens_cross_attention_module_identifiers:
assert False, f"name {name} cannot appear more than once"
self.tokens_cross_attention_module_identifiers.append(name)
@@ -175,6 +173,135 @@ class Context:
map_dict[offset] = slice.to('cpu')
class InvokeAICrossAttentionMixin:
"""
Enable InvokeAI-flavoured CrossAttention calculation, which does aggressive low-memory slicing and calls
through both to an attention_slice_wrangler and a slicing_strategy_getter for custom attention map wrangling
and dymamic slicing strategy selection.
"""
def __init__(self):
self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
self.attention_slice_wrangler = None
self.slicing_strategy_getter = None
self.attention_slice_calculated_callback = None
def set_attention_slice_wrangler(self, wrangler: Optional[Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]]):
'''
Set custom attention calculator to be called when attention is calculated
:param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size),
which returns either the suggested_attention_slice or an adjusted equivalent.
`module` is the current CrossAttention module for which the callback is being invoked.
`suggested_attention_slice` is the default-calculated attention slice
`dim` is -1 if the attenion map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length.
Pass None to use the default attention calculation.
:return:
'''
self.attention_slice_wrangler = wrangler
def set_slicing_strategy_getter(self, getter: Optional[Callable[[nn.Module], tuple[int,int]]]):
self.slicing_strategy_getter = getter
def set_attention_slice_calculated_callback(self, callback: Optional[Callable[[torch.Tensor], None]]):
self.attention_slice_calculated_callback = callback
def einsum_lowest_level(self, query, key, value, dim, offset, slice_size):
# calculate attention scores
#attention_scores = torch.einsum('b i d, b j d -> b i j', q, k)
attention_scores = torch.baddbmm(
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
query,
key.transpose(-1, -2),
beta=0,
alpha=self.scale,
)
# calculate attention slice by taking the best scores for each latent pixel
default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
attention_slice_wrangler = self.attention_slice_wrangler
if attention_slice_wrangler is not None:
attention_slice = attention_slice_wrangler(self, default_attention_slice, dim, offset, slice_size)
else:
attention_slice = default_attention_slice
if self.attention_slice_calculated_callback is not None:
self.attention_slice_calculated_callback(attention_slice, dim, offset, slice_size)
hidden_states = torch.bmm(attention_slice, value)
return hidden_states
def einsum_op_slice_dim0(self, q, k, v, slice_size):
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
for i in range(0, q.shape[0], slice_size):
end = i + slice_size
r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size)
return r
def einsum_op_slice_dim1(self, q, k, v, slice_size):
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size)
return r
def einsum_op_mps_v1(self, q, k, v):
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
return self.einsum_lowest_level(q, k, v, None, None, None)
else:
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
return self.einsum_op_slice_dim1(q, k, v, slice_size)
def einsum_op_mps_v2(self, q, k, v):
if self.mem_total_gb > 8 and q.shape[1] <= 4096:
return self.einsum_lowest_level(q, k, v, None, None, None)
else:
return self.einsum_op_slice_dim0(q, k, v, 1)
def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb):
size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
if size_mb <= max_tensor_mb:
return self.einsum_lowest_level(q, k, v, None, None, None)
div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
if div <= q.shape[0]:
return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div)
return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1))
def einsum_op_cuda(self, q, k, v):
# check if we already have a slicing strategy (this should only happen during cross-attention controlled generation)
slicing_strategy_getter = self.slicing_strategy_getter
if slicing_strategy_getter is not None:
(dim, slice_size) = slicing_strategy_getter(self)
if dim is not None:
# print("using saved slicing strategy with dim", dim, "slice size", slice_size)
if dim == 0:
return self.einsum_op_slice_dim0(q, k, v, slice_size)
elif dim == 1:
return self.einsum_op_slice_dim1(q, k, v, slice_size)
# fallback for when there is no saved strategy, or saved strategy does not slice
mem_free_total = self.cached_mem_free_total or get_mem_free_total(q.device)
# Divide factor of safety as there's copying and fragmentation
return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
def get_invokeai_attention_mem_efficient(self, q, k, v):
if q.device.type == 'cuda':
#print("in get_attention_mem_efficient with q shape", q.shape, ", k shape", k.shape, ", free memory is", get_mem_free_total(q.device))
return self.einsum_op_cuda(q, k, v)
if q.device.type == 'mps' or q.device.type == 'cpu':
if self.mem_total_gb >= 32:
return self.einsum_op_mps_v1(q, k, v)
return self.einsum_op_mps_v2(q, k, v)
# Smaller slices are faster due to L2/L3/SLC caches.
# Tested on i7 with 8MB L3 cache.
return self.einsum_op_tensor_mem(q, k, v, 32)
def remove_cross_attention_control(model):
remove_attention_function(model)
@@ -210,8 +337,7 @@ def setup_cross_attention_control(model, context: Context):
inject_attention_function(model, context)
-def get_attention_modules(model, which: CrossAttentionType):
+def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
-# cross_attention_class: type = ldm.modules.attention.CrossAttention
cross_attention_class: type = InvokeAIDiffusersCrossAttention
which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2"
attention_module_tuples = [(name,module) for name, module in model.named_modules() if
@@ -220,7 +346,7 @@ def get_attention_modules(model, which: CrossAttentionType):
expected_count = 16
if cross_attention_modules_in_model_count != expected_count:
# non-fatal error but .swap() won't work.
-print(f"Error! CrossAttentionControl found an unexpected number of InvokeAICrossAttention modules in the model " +
+print(f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model " +
f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching failed " +
f"or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, " +
f"and/or update the {expected_count} above to an appropriate number, and/or find and inform someone who knows " +
@@ -266,7 +392,7 @@ def inject_attention_function(unet, context: Context):
return attention_slice
-cross_attention_modules = get_attention_modules(unet, CrossAttentionType.TOKENS) + get_attention_modules(unet, CrossAttentionType.SELF)
+cross_attention_modules = get_cross_attention_modules(unet, CrossAttentionType.TOKENS) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
for identifier, module in cross_attention_modules:
module.identifier = identifier
try:
@@ -282,7 +408,7 @@ def inject_attention_function(unet, context: Context):
def remove_attention_function(unet):
-cross_attention_modules = get_attention_modules(unet, CrossAttentionType.TOKENS) + get_attention_modules(unet, CrossAttentionType.SELF)
+cross_attention_modules = get_cross_attention_modules(unet, CrossAttentionType.TOKENS) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
for identifier, module in cross_attention_modules:
try:
# clear wrangler callback
@@ -316,7 +442,6 @@ def get_mem_free_total(device):
return mem_free_total
class InvokeAICrossAttentionMixin:
"""
Enable InvokeAI-flavoured CrossAttention calculation, which does aggressive low-memory slicing and calls
@@ -452,4 +577,3 @@ class InvokeAIDiffusersCrossAttention(diffusers.models.attention.CrossAttention,
hidden_states = self.reshape_batch_dim_to_heads(damian_result)
return hidden_states
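A standalone check (assuming torch is available) that the baddbmm/softmax/bmm path in einsum_lowest_level() computes ordinary scaled dot-product attention, with the scale applied exactly once via alpha; this is also why CrossAttention.forward() in ldm/modules/attention.py temporarily sets self.scale to 1 before delegating to get_invokeai_attention_mem_efficient():

    import torch

    q = torch.randn(2, 64, 40)   # (batch*heads, query tokens, dim_head)
    k = torch.randn(2, 77, 40)
    v = torch.randn(2, 77, 40)
    scale = 40 ** -0.5

    # einsum_lowest_level(): attention scores via baddbmm (alpha applies the scale) ...
    scores = torch.baddbmm(
        torch.empty(q.shape[0], q.shape[1], k.shape[1], dtype=q.dtype, device=q.device),
        q, k.transpose(-1, -2), beta=0, alpha=scale,
    )
    hidden = torch.bmm(scores.softmax(dim=-1), v)

    # ... matches the textbook formulation.
    reference = torch.softmax(q @ k.transpose(-1, -2) * scale, dim=-1) @ v
    assert torch.allclose(hidden, reference, atol=1e-5)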

ldm/models/diffusion/cross_attention_map_saving.py

@@ -0,0 +1,95 @@
import math
import PIL
import torch
from torchvision.transforms.functional import resize as tv_resize, InterpolationMode
from ldm.models.diffusion.cross_attention_control import get_cross_attention_modules, CrossAttentionType
class AttentionMapSaver():
def __init__(self, token_ids: range, latents_shape: torch.Size):
self.token_ids = token_ids
self.latents_shape = latents_shape
#self.collated_maps = #torch.zeros([len(token_ids), latents_shape[0], latents_shape[1]])
self.collated_maps = {}
def clear_maps(self):
self.collated_maps = {}
def add_attention_maps(self, maps: torch.Tensor, key: str):
"""
Accumulate the given attention maps and store by summing with existing maps at the passed-in key (if any).
:param maps: Attention maps to store. Expected shape [A, (H*W), N] where A is attention heads count, H and W are the map size (fixed per-key) and N is the number of tokens (typically 77).
:param key: Storage key. If a map already exists for this key it will be summed with the incoming data. In this case the maps sizes (H and W) should match.
:return: None
"""
key_and_size = f'{key}_{maps.shape[1]}'
# extract desired tokens
maps = maps[:, :, self.token_ids]
# merge attention heads to a single map per token
maps = torch.sum(maps, 0)
# store
if key_and_size not in self.collated_maps:
self.collated_maps[key_and_size] = torch.zeros_like(maps, device='cpu')
self.collated_maps[key_and_size] += maps.cpu()
def write_maps_to_disk(self, path: str):
pil_image = self.get_stacked_maps_image()
pil_image.save(path, 'PNG')
def get_stacked_maps_image(self) -> PIL.Image:
"""
Scale all collected attention maps to the same size, blend them together and return as an image.
:return: An image containing a vertical stack of blended attention maps, one for each requested token.
"""
num_tokens = len(self.token_ids)
if num_tokens == 0:
return None
latents_height = self.latents_shape[0]
latents_width = self.latents_shape[1]
merged = None
for key, maps in self.collated_maps.items():
# maps has shape [(H*W), N] for N tokens
# but we want [N, H, W]
this_scale_factor = math.sqrt(maps.shape[0] / (latents_width * latents_height))
this_maps_height = int(float(latents_height) * this_scale_factor)
this_maps_width = int(float(latents_width) * this_scale_factor)
# and we need to do some dimension juggling
maps = torch.reshape(torch.swapdims(maps, 0, 1), [num_tokens, this_maps_height, this_maps_width])
# scale to output size if necessary
if this_scale_factor != 1:
maps = tv_resize(maps, [latents_height, latents_width], InterpolationMode.BICUBIC)
# normalize
maps_min = torch.min(maps)
maps_range = torch.max(maps) - maps_min
#print(f"map {key} size {[this_maps_width, this_maps_height]} range {[maps_min, maps_min + maps_range]}")
maps_normalized = (maps - maps_min) / maps_range
# expand to (-0.1, 1.1) and clamp
maps_normalized_expanded = maps_normalized * 1.1 - 0.05
maps_normalized_expanded_clamped = torch.clamp(maps_normalized_expanded, 0, 1)
# merge together, producing a vertical stack
maps_stacked = torch.reshape(maps_normalized_expanded_clamped, [num_tokens * latents_height, latents_width])
if merged is None:
merged = maps_stacked
else:
# screen blend
merged = 1 - (1 - maps_stacked)*(1 - merged)
if merged is None:
return None
merged_bytes = merged.mul(0xff).byte()
return PIL.Image.fromarray(merged_bytes.numpy(), mode='L')
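A minimal usage sketch of the new saver, with random tensors standing in for real tokens-cross-attention slices (shapes follow the add_attention_maps() docstring):

    import torch
    from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver

    # Save maps for 5 prompt tokens (ids 1..5, skipping BOS) over 64x64 latents.
    saver = AttentionMapSaver(token_ids=range(1, 6), latents_shape=torch.Size([64, 64]))

    # One fake attention slice: 8 heads, 64*64 latent pixels, 77 token positions.
    saver.add_attention_maps(torch.rand(8, 64 * 64, 77), key='mid')

    image = saver.get_stacked_maps_image()   # greyscale PIL image, maps stacked vertically
    print(image.size)                        # (64, 320) -> one 64x64 map per token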

ldm/models/diffusion/ksampler.py

@@ -4,6 +4,7 @@ import k_diffusion as K
import torch
from torch import nn
+from .cross_attention_map_saving import AttentionMapSaver
from .sampler import Sampler
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent
@@ -36,6 +37,7 @@ class CFGDenoiser(nn.Module):
self.invokeai_diffuser = InvokeAIDiffuserComponent(model,
model_forward_callback=lambda x, sigma, cond: self.inner_model(x, sigma, cond=cond))
def prepare_to_sample(self, t_enc, **kwargs):
extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
@@ -106,12 +108,12 @@ class KSampler(Sampler):
else:
print(f'>> Ksampler using karras noise schedule (steps < {self.karras_max})')
self.sigmas = self.karras_sigmas
# ALERT: We are completely overriding the sample() method in the base class, which
# means that inpainting will not work. To get this to work we need to be able to
# modify the inner loop of k_heun, k_lms, etc, as is done in an ugly way
# in the lstein/k-diffusion branch.
@torch.no_grad()
def decode(
self,
@@ -145,7 +147,7 @@ class KSampler(Sampler):
@torch.no_grad()
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
return x0
# Most of these arguments are ignored and are only present for compatibility with
# other samples
@torch.no_grad()
@@ -158,6 +160,7 @@ class KSampler(Sampler):
callback=None,
normals_sequence=None,
img_callback=None,
+attention_maps_callback=None,
quantize_x0=False,
eta=0.0,
mask=None,
@@ -171,7 +174,7 @@ class KSampler(Sampler):
log_every_t=100,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
-extra_conditioning_info=None,
+extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo=None,
threshold = 0,
perlin = 0,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
@@ -204,6 +207,12 @@ class KSampler(Sampler):
model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10))
model_wrap_cfg.prepare_to_sample(S, extra_conditioning_info=extra_conditioning_info)
+attention_map_token_ids = range(1, extra_conditioning_info.tokens_count_including_eos_bos - 1)
+attention_maps_saver = None if attention_maps_callback is None else AttentionMapSaver(token_ids = attention_map_token_ids, latents_shape=x.shape[-2:])
+if attention_maps_callback is not None:
+model_wrap_cfg.invokeai_diffuser.setup_attention_map_saving(attention_maps_saver)
extra_args = {
'cond': conditioning,
'uncond': unconditional_conditioning,
@@ -217,6 +226,8 @@ class KSampler(Sampler):
),
None,
)
+if attention_maps_callback is not None:
+attention_maps_callback(attention_maps_saver)
return sampling_result
# this code will support inpainting if and when ksampler API modified or
@@ -248,7 +259,7 @@ class KSampler(Sampler):
# terrible, confusing names here
steps = self.ddim_num_steps
t_enc = self.t_enc
# sigmas is a full steps in length, but t_enc might
# be less. We start in the middle of the sigma array
# and work our way to the end after t_enc steps.
@@ -280,7 +291,7 @@ class KSampler(Sampler):
return x_T + x
else:
return x
def prepare_to_sample(self,t_enc,**kwargs):
self.t_enc = t_enc
self.model_wrap = None
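The token ids handed to AttentionMapSaver above simply drop the BOS/EOS markers that tokens_count_including_eos_bos accounts for; for example, for a prompt that tokenizes to three tokens:

    tokens_count_including_eos_bos = 5                    # [BOS, t1, t2, t3, EOS]
    attention_map_token_ids = range(1, tokens_count_including_eos_bos - 1)
    print(list(attention_map_token_ids))                  # [1, 2, 3] -> real prompt tokens only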

ldm/models/diffusion/shared_invokeai_diffusion.py

@@ -5,8 +5,8 @@ from typing import Callable, Optional, Union
import torch
from ldm.models.diffusion.cross_attention_control import Arguments, \
-remove_cross_attention_control, setup_cross_attention_control, Context
+remove_cross_attention_control, setup_cross_attention_control, Context, get_cross_attention_modules, CrossAttentionType
-from ldm.modules.attention import get_mem_free_total
+from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
class InvokeAIDiffuserComponent:
@@ -21,7 +21,8 @@ class InvokeAIDiffuserComponent:
class ExtraConditioningInfo:
-def __init__(self, cross_attention_control_args: Optional[Arguments]):
+def __init__(self, tokens_count_including_eos_bos:int, cross_attention_control_args: Optional[Arguments]):
+self.tokens_count_including_eos_bos = tokens_count_including_eos_bos
self.cross_attention_control_args = cross_attention_control_args
@property
@@ -53,7 +54,25 @@ class InvokeAIDiffuserComponent:
self.cross_attention_control_context = None
remove_cross_attention_control(self.model)
+def setup_attention_map_saving(self, saver: AttentionMapSaver):
+def callback(slice, dim, offset, slice_size, key):
+if dim is not None:
+# sliced tokens attention map saving is not implemented
+return
+saver.add_attention_maps(slice, key)
+tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
+for identifier, module in tokens_cross_attention_modules:
+key = ('down' if identifier.startswith('down') else
+'up' if identifier.startswith('up') else
+'mid')
+module.set_attention_slice_calculated_callback(
+lambda slice, dim, offset, slice_size, key=key: callback(slice, dim, offset, slice_size, key))
+def remove_attention_map_saving(self):
+tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
+for _, module in tokens_cross_attention_modules:
+module.set_attention_slice_calculated_callback(None)
def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor,
unconditioning: Union[torch.Tensor,dict],
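One detail worth noting in setup_attention_map_saving() above: the lambda captures key as a default argument (key=key) so that each module keeps its own 'down'/'up'/'mid' label; a plain closure would late-bind and every callback would see the loop variable's final value. A tiny self-contained illustration:

    labels = ['down', 'up', 'mid']

    late_bound = [lambda: key for key in labels]
    default_bound = [lambda key=key: key for key in labels]

    print([f() for f in late_bound])     # ['mid', 'mid', 'mid'] -- all see the final value
    print([f() for f in default_bound])  # ['down', 'up', 'mid'] -- bound per iteration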

ldm/modules/attention.py

@@ -10,8 +10,6 @@ from einops import rearrange, repeat
from ldm.models.diffusion.cross_attention_control import InvokeAICrossAttentionMixin
from ldm.modules.diffusionmodules.util import checkpoint
-import psutil
def exists(val):
return val is not None
@@ -165,10 +163,10 @@ def get_mem_free_total(device):
return mem_free_total
class CrossAttention(nn.Module, InvokeAICrossAttentionMixin):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
print(f"Warning! ldm.modules.attention.CrossAttention is no longer being maintained. Please use InvokeAICrossAttention instead.")
super().__init__()
+InvokeAICrossAttentionMixin.__init__(self)
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
@@ -184,7 +182,6 @@ class CrossAttention(nn.Module, InvokeAICrossAttentionMixin):
nn.Dropout(dropout)
)
def forward(self, x, context=None, mask=None):
h = self.heads
@@ -196,7 +193,7 @@ class CrossAttention(nn.Module, InvokeAICrossAttentionMixin):
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
-# prevent scale being applied twice
+# don't apply scale twice
cached_scale = self.scale
self.scale = 1
r = self.get_invokeai_attention_mem_efficient(q, k, v)