Merge remote-tracking branch 'origin/main' into dev/diffusers

# Conflicts:
#	backend/invoke_ai_web_server.py
#	ldm/generate.py
#	ldm/invoke/CLI.py
#	ldm/invoke/generator/base.py
#	ldm/invoke/generator/txt2img.py
#	ldm/models/diffusion/cross_attention_control.py
#	ldm/modules/attention.py
Kevin Turner 2022-12-10 08:43:37 -08:00
commit 63532226a5
13 changed files with 515 additions and 181 deletions

View File

@ -21,6 +21,7 @@ from backend.modules.get_canvas_generation_mode import (
get_canvas_generation_mode,
)
from backend.modules.parameters import parameters_to_command
from ldm.generate import Generate
from ldm.invoke.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash
from ldm.invoke.generator.diffusers_pipeline import PipelineIntermediateState
from ldm.invoke.generator.inpaint import infill_methods
@ -38,7 +39,7 @@ if not os.path.isabs(args.outdir):
class InvokeAIWebServer:
def __init__(self, generate, gfpgan, codeformer, esrgan) -> None:
def __init__(self, generate: Generate, gfpgan, codeformer, esrgan) -> None:
self.host = args.host
self.port = args.port
@ -906,16 +907,13 @@ class InvokeAIWebServer:
},
)
if generation_parameters["progress_latents"]:
image = self.generate.sample_to_lowres_estimated_image(sample)
(width, height) = image.size
width *= 8
height *= 8
buffered = io.BytesIO()
image.save(buffered, format="PNG")
img_base64 = "data:image/png;base64," + base64.b64encode(
buffered.getvalue()
).decode("UTF-8")
img_base64 = image_to_dataURL(image)
self.socketio.emit(
"intermediateResult",
{
@ -933,7 +931,7 @@ class InvokeAIWebServer:
self.socketio.emit("progressUpdate", progress.to_formatted_dict())
eventlet.sleep(0)
def image_done(image, seed, first_seed):
def image_done(image, seed, first_seed, attention_maps_image=None):
if self.canceled.is_set():
raise CanceledException
@ -1095,6 +1093,12 @@ class InvokeAIWebServer:
self.socketio.emit("progressUpdate", progress.to_formatted_dict())
eventlet.sleep(0)
parsed_prompt, _ = get_prompt_structure(generation_parameters["prompt"])
tokens = None if type(parsed_prompt) is Blend else \
get_tokens_for_prompt(self.generate.model, parsed_prompt)
attention_maps_image_base64_url = None if attention_maps_image is None \
else image_to_dataURL(attention_maps_image)
self.socketio.emit(
"generationResult",
{
@ -1107,6 +1111,8 @@ class InvokeAIWebServer:
"height": height,
"boundingBox": original_bounding_box,
"generationMode": generation_parameters["generation_mode"],
"attentionMaps": attention_maps_image_base64_url,
"tokens": tokens,
},
)
eventlet.sleep(0)
@ -1118,7 +1124,7 @@ class InvokeAIWebServer:
self.generate.prompt2image(
**generation_parameters,
step_callback=image_progress,
image_callback=image_done,
image_callback=image_done
)
except KeyboardInterrupt:
@ -1565,6 +1571,19 @@ def dataURL_to_image(dataURL: str) -> ImageType:
)
return image
"""
Converts an image into a base64 image dataURL.
"""
def image_to_dataURL(image: ImageType) -> str:
buffered = io.BytesIO()
image.save(buffered, format="PNG")
image_base64 = "data:image/png;base64," + base64.b64encode(
buffered.getvalue()
).decode("UTF-8")
return image_base64
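For reference, a standalone sketch of the same PNG-to-data-URL conversion performed by the new image_to_dataURL helper; the function name png_data_url and the Pillow-only typing are illustrative, not part of this diff.

import base64
import io

from PIL import Image

def png_data_url(image: Image.Image) -> str:
    # Encode the image as PNG in memory, then wrap the bytes in a base64 data URL,
    # mirroring the image_to_dataURL helper introduced above.
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    return "data:image/png;base64," + base64.b64encode(buffered.getvalue()).decode("UTF-8")

# The resulting string can be used directly as the src of an <img> tag, or emitted
# over the socket the way image_progress/image_done do.
print(png_data_url(Image.new("RGB", (8, 8), "red"))[:40], "...")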
"""
Converts a base64 image dataURL into bytes.

View File

@ -16,7 +16,6 @@ import numpy as np
import skimage
import torch
import transformers
from PIL import Image, ImageOps
from diffusers import HeunDiscreteScheduler
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
@ -27,6 +26,7 @@ from diffusers.schedulers.scheduling_ipndm import IPNDMScheduler
from diffusers.schedulers.scheduling_lms_discrete import LMSDiscreteScheduler
from diffusers.schedulers.scheduling_pndm import PNDMScheduler
from omegaconf import OmegaConf
from PIL import Image, ImageOps
from pytorch_lightning import seed_everything, logging
from ldm.invoke.args import metadata_from_png
@ -461,7 +461,7 @@ class Generate:
try:
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(
prompt, model =self.model,
skip_normalize=skip_normalize,
skip_normalize_legacy_blend=skip_normalize,
log_tokens =self.log_tokenization
)
@ -613,8 +613,8 @@ class Generate:
# todo: cross-attention control
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(
prompt, model =self.model,
skip_normalize=opt.skip_normalize,
log_tokens =opt.log_tokenization
skip_normalize_legacy_blend=opt.skip_normalize,
log_tokens =ldm.invoke.conditioning.log_tokenization
)
if tool in ('gfpgan','codeformer','upscale'):

View File

@ -8,6 +8,7 @@ import time
import traceback
import yaml
from ldm.generate import Generate
from ldm.invoke.globals import Globals
from ldm.invoke.prompt_parser import PromptParser
from ldm.invoke.readline import get_completer, Completer
@ -282,7 +283,7 @@ def main_loop(gen, opt):
prefix = file_writer.unique_prefix()
step_callback = make_step_callback(gen, opt, prefix) if opt.save_intermediates > 0 else None
def image_writer(image, seed, upscaled=False, first_seed=None, use_prefix=None, prompt_in=None):
def image_writer(image, seed, upscaled=False, first_seed=None, use_prefix=None, prompt_in=None, attention_maps_image=None):
# note the seed is the seed of the current image
# the first_seed is the original seed that noise is added to
# when the -v switch is used to generate variations
@ -790,7 +791,7 @@ def get_next_command(infile=None) -> str: # command string
print(f'#{command}')
return command
def invoke_ai_web_server_loop(gen, gfpgan, codeformer, esrgan):
def invoke_ai_web_server_loop(gen: Generate, gfpgan, codeformer, esrgan):
print('\n* --web was specified, starting web server...')
from backend.invoke_ai_web_server import InvokeAIWebServer
# Change working directory to the stable-diffusion directory

View File

@ -7,20 +7,46 @@ get_uc_and_c_and_ec() get the conditioned and unconditioned latent, an
'''
import re
from difflib import SequenceMatcher
from typing import Union
import torch
from .prompt_parser import PromptParser, Blend, FlattenedPrompt, \
CrossAttentionControlledFragment, CrossAttentionControlSubstitute, Fragment, log_tokenization
CrossAttentionControlledFragment, CrossAttentionControlSubstitute, Fragment
from ..models.diffusion import cross_attention_control
from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder
def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_normalize=False):
def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False):
prompt, negative_prompt = get_prompt_structure(prompt_string,
skip_normalize_legacy_blend=skip_normalize_legacy_blend)
conditioning = _get_conditioning_for_prompt(prompt, negative_prompt, model, log_tokens)
return conditioning
def get_prompt_structure(prompt_string, skip_normalize_legacy_blend: bool = False) -> (
Union[FlattenedPrompt, Blend], FlattenedPrompt):
"""
parse the passed-in prompt string and return tuple (positive_prompt, negative_prompt)
"""
prompt, negative_prompt = _parse_prompt_string(prompt_string,
skip_normalize_legacy_blend=skip_normalize_legacy_blend)
return prompt, negative_prompt
def get_tokens_for_prompt(model, parsed_prompt: FlattenedPrompt) -> [str]:
text_fragments = [x.text if type(x) is Fragment else
(" ".join([f.text for f in x.original]) if type(x) is CrossAttentionControlSubstitute else
str(x))
for x in parsed_prompt.children]
text = " ".join(text_fragments)
tokens = model.cond_stage_model.tokenizer.tokenize(text)
return tokens
def _parse_prompt_string(prompt_string_uncleaned, skip_normalize_legacy_blend=False) -> Union[FlattenedPrompt, Blend]:
# Extract Unconditioned Words From Prompt
unconditioned_words = ''
unconditional_regex = r'\[(.*?)\]'
@ -39,7 +65,7 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
pp = PromptParser()
parsed_prompt: Union[FlattenedPrompt, Blend] = None
legacy_blend: Blend = pp.parse_legacy_blend(prompt_string_cleaned)
legacy_blend: Blend = pp.parse_legacy_blend(prompt_string_cleaned, skip_normalize_legacy_blend)
if legacy_blend is not None:
parsed_prompt = legacy_blend
else:
@ -47,129 +73,150 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
parsed_prompt = pp.parse_conjunction(prompt_string_cleaned).prompts[0]
parsed_negative_prompt: FlattenedPrompt = pp.parse_conjunction(unconditioned_words).prompts[0]
return parsed_prompt, parsed_negative_prompt
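The bracket handling above pulls [bracketed] text out of the prompt and treats it as the negative prompt. A self-contained illustration of that regex step; split_bracketed_negatives is a hypothetical stand-in, and the real _parse_prompt_string additionally runs the PromptParser and legacy blend handling.

import re

def split_bracketed_negatives(prompt: str) -> tuple[str, str]:
    # Text inside [...] becomes the negative ("unconditioned") prompt;
    # the brackets and their contents are stripped from the positive prompt.
    unconditional_regex = r'\[(.*?)\]'
    negatives = " ".join(re.findall(unconditional_regex, prompt))
    positives = re.sub(unconditional_regex, "", prompt).strip()
    return positives, negatives

print(split_bracketed_negatives("a castle on a hill [blurry, low quality]"))
# ('a castle on a hill', 'blurry, low quality')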
def _get_conditioning_for_prompt(parsed_prompt: Union[Blend, FlattenedPrompt], parsed_negative_prompt: FlattenedPrompt,
model, log_tokens=False) \
-> tuple[torch.Tensor, torch.Tensor, InvokeAIDiffuserComponent.ExtraConditioningInfo]:
"""
Process prompt structure and tokens, and return (conditioning, unconditioning, extra_conditioning_info)
"""
if log_tokens:
print(f">> Parsed prompt to {parsed_prompt}")
print(f">> Parsed negative prompt to {parsed_negative_prompt}")
conditioning = None
cac_args:cross_attention_control.Arguments = None
cac_args: cross_attention_control.Arguments = None
if type(parsed_prompt) is Blend:
blend: Blend = parsed_prompt
embeddings_to_blend = None
for i,flattened_prompt in enumerate(blend.prompts):
this_embedding, _ = build_embeddings_and_tokens_for_flattened_prompt(model,
flattened_prompt,
log_tokens=log_tokens,
log_display_label=f"(blend part {i+1}, weight={blend.weights[i]})" )
embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat(
(embeddings_to_blend, this_embedding))
conditioning = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0),
blend.weights,
normalize=blend.normalize_weights)
else:
flattened_prompt: FlattenedPrompt = parsed_prompt
wants_cross_attention_control = type(flattened_prompt) is not Blend \
and any([issubclass(type(x), CrossAttentionControlledFragment) for x in flattened_prompt.children])
if wants_cross_attention_control:
original_prompt = FlattenedPrompt()
edited_prompt = FlattenedPrompt()
# for name, a0, a1, b0, b1 in edit_opcodes: only name == 'equal' is currently parsed
original_token_count = 0
edited_token_count = 0
edit_options = []
edit_opcodes = []
# beginning of sequence
edit_opcodes.append(('equal', original_token_count, original_token_count+1, edited_token_count, edited_token_count+1))
edit_options.append(None)
original_token_count += 1
edited_token_count += 1
for fragment in flattened_prompt.children:
if type(fragment) is CrossAttentionControlSubstitute:
original_prompt.append(fragment.original)
edited_prompt.append(fragment.edited)
conditioning = _get_conditioning_for_blend(model, parsed_prompt, log_tokens)
elif type(parsed_prompt) is FlattenedPrompt:
if parsed_prompt.wants_cross_attention_control:
conditioning, cac_args = _get_conditioning_for_cross_attention_control(model, parsed_prompt, log_tokens)
to_replace_token_count = get_tokens_length(model, fragment.original)
replacement_token_count = get_tokens_length(model, fragment.edited)
edit_opcodes.append(('replace',
original_token_count, original_token_count + to_replace_token_count,
edited_token_count, edited_token_count + replacement_token_count
))
original_token_count += to_replace_token_count
edited_token_count += replacement_token_count
edit_options.append(fragment.options)
#elif type(fragment) is CrossAttentionControlAppend:
# edited_prompt.append(fragment.fragment)
else:
# regular fragment
original_prompt.append(fragment)
edited_prompt.append(fragment)
count = get_tokens_length(model, [fragment])
edit_opcodes.append(('equal', original_token_count, original_token_count+count, edited_token_count, edited_token_count+count))
edit_options.append(None)
original_token_count += count
edited_token_count += count
# end of sequence
edit_opcodes.append(('equal', original_token_count, original_token_count+1, edited_token_count, edited_token_count+1))
edit_options.append(None)
original_token_count += 1
edited_token_count += 1
original_embeddings, original_tokens = build_embeddings_and_tokens_for_flattened_prompt(model,
original_prompt,
log_tokens=log_tokens,
log_display_label="(.swap originals)")
# naïvely building a single edited_embeddings like this disregards the effects of changing the absolute location of
# subsequent tokens when there is >1 edit and earlier edits change the total token count.
# eg "a cat.swap(smiling dog, s_start=0.5) eating a hotdog.swap(pizza)" - when the 'pizza' edit is active but the
# 'cat' edit is not, the 'pizza' feature vector will nevertheless be affected by the introduction of the extra
# token 'smiling' in the inactive 'cat' edit.
# todo: build multiple edited_embeddings, one for each edit, and pass just the edited fragments through to the CrossAttentionControl functions
edited_embeddings, edited_tokens = build_embeddings_and_tokens_for_flattened_prompt(model,
edited_prompt,
log_tokens=log_tokens,
log_display_label="(.swap replacements)")
conditioning = original_embeddings
edited_conditioning = edited_embeddings
#print('>> got edit_opcodes', edit_opcodes, 'options', edit_options)
cac_args = cross_attention_control.Arguments(
edited_conditioning = edited_conditioning,
edit_opcodes = edit_opcodes,
edit_options = edit_options
)
else:
conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model,
flattened_prompt,
log_tokens=log_tokens,
log_display_label="(prompt)")
conditioning, _ = _get_embeddings_and_tokens_for_prompt(model,
parsed_prompt,
log_tokens=log_tokens,
log_display_label="(prompt)")
else:
raise ValueError(f"parsed_prompt is '{type(parsed_prompt)}' which is not a supported prompt type")
unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model,
parsed_negative_prompt,
log_tokens=log_tokens,
log_display_label="(unconditioning)")
unconditioning, _ = _get_embeddings_and_tokens_for_prompt(model,
parsed_negative_prompt,
log_tokens=log_tokens,
log_display_label="(unconditioning)")
if isinstance(conditioning, dict):
# hybrid conditioning is in play
unconditioning, conditioning = flatten_hybrid_conditioning(unconditioning, conditioning)
unconditioning, conditioning = _flatten_hybrid_conditioning(unconditioning, conditioning)
if cac_args is not None:
print(">> Hybrid conditioning cannot currently be combined with cross attention control. Cross attention control will be ignored.")
print(
">> Hybrid conditioning cannot currently be combined with cross attention control. Cross attention control will be ignored.")
cac_args = None
eos_token_index = 1
if type(parsed_prompt) is not Blend:
tokens = get_tokens_for_prompt(model, parsed_prompt)
eos_token_index = len(tokens)+1
return (
unconditioning, conditioning, InvokeAIDiffuserComponent.ExtraConditioningInfo(
tokens_count_including_eos_bos=eos_token_index + 1,
cross_attention_control_args=cac_args
)
)
def build_token_edit_opcodes(original_tokens, edited_tokens):
original_tokens = original_tokens.cpu().numpy()[0]
edited_tokens = edited_tokens.cpu().numpy()[0]
def _get_conditioning_for_cross_attention_control(model, prompt: FlattenedPrompt, log_tokens: bool = True):
original_prompt = FlattenedPrompt()
edited_prompt = FlattenedPrompt()
# for name, a0, a1, b0, b1 in edit_opcodes: only name == 'equal' is currently parsed
original_token_count = 0
edited_token_count = 0
edit_options = []
edit_opcodes = []
# beginning of sequence
edit_opcodes.append(
('equal', original_token_count, original_token_count + 1, edited_token_count, edited_token_count + 1))
edit_options.append(None)
original_token_count += 1
edited_token_count += 1
for fragment in prompt.children:
if type(fragment) is CrossAttentionControlSubstitute:
original_prompt.append(fragment.original)
edited_prompt.append(fragment.edited)
return SequenceMatcher(None, original_tokens, edited_tokens).get_opcodes()
to_replace_token_count = _get_tokens_length(model, fragment.original)
replacement_token_count = _get_tokens_length(model, fragment.edited)
edit_opcodes.append(('replace',
original_token_count, original_token_count + to_replace_token_count,
edited_token_count, edited_token_count + replacement_token_count
))
original_token_count += to_replace_token_count
edited_token_count += replacement_token_count
edit_options.append(fragment.options)
# elif type(fragment) is CrossAttentionControlAppend:
# edited_prompt.append(fragment.fragment)
else:
# regular fragment
original_prompt.append(fragment)
edited_prompt.append(fragment)
def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt, log_tokens: bool=False, log_display_label: str=None):
count = _get_tokens_length(model, [fragment])
edit_opcodes.append(('equal', original_token_count, original_token_count + count, edited_token_count,
edited_token_count + count))
edit_options.append(None)
original_token_count += count
edited_token_count += count
# end of sequence
edit_opcodes.append(
('equal', original_token_count, original_token_count + 1, edited_token_count, edited_token_count + 1))
edit_options.append(None)
original_token_count += 1
edited_token_count += 1
original_embeddings, original_tokens = _get_embeddings_and_tokens_for_prompt(model,
original_prompt,
log_tokens=log_tokens,
log_display_label="(.swap originals)")
# naïvely building a single edited_embeddings like this disregards the effects of changing the absolute location of
# subsequent tokens when there is >1 edit and earlier edits change the total token count.
# eg "a cat.swap(smiling dog, s_start=0.5) eating a hotdog.swap(pizza)" - when the 'pizza' edit is active but the
# 'cat' edit is not, the 'pizza' feature vector will nevertheless be affected by the introduction of the extra
# token 'smiling' in the inactive 'cat' edit.
# todo: build multiple edited_embeddings, one for each edit, and pass just the edited fragments through to the CrossAttentionControl functions
edited_embeddings, edited_tokens = _get_embeddings_and_tokens_for_prompt(model,
edited_prompt,
log_tokens=log_tokens,
log_display_label="(.swap replacements)")
conditioning = original_embeddings
edited_conditioning = edited_embeddings
# print('>> got edit_opcodes', edit_opcodes, 'options', edit_options)
cac_args = cross_attention_control.Arguments(
edited_conditioning=edited_conditioning,
edit_opcodes=edit_opcodes,
edit_options=edit_options
)
return conditioning, cac_args
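The edit_opcodes built above use the same (tag, start_a, end_a, start_b, end_b) convention as difflib opcodes, which the removed build_token_edit_opcodes produced directly from SequenceMatcher. A toy illustration of that convention over made-up token lists:

from difflib import SequenceMatcher

# Token-index spans in the (tag, i1, i2, j1, j2) form expected by
# cross_attention_control.Arguments(edit_opcodes=...).
original_tokens = ["<bos>", "a", "cat", "eating", "a", "hotdog", "<eos>"]
edited_tokens = ["<bos>", "a", "smiling", "dog", "eating", "a", "hotdog", "<eos>"]

for tag, i1, i2, j1, j2 in SequenceMatcher(None, original_tokens, edited_tokens).get_opcodes():
    print(tag, original_tokens[i1:i2], "->", edited_tokens[j1:j2])
# equal   ['<bos>', 'a'] -> ['<bos>', 'a']
# replace ['cat'] -> ['smiling', 'dog']
# equal   ['eating', 'a', 'hotdog', '<eos>'] -> ['eating', 'a', 'hotdog', '<eos>']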
def _get_conditioning_for_blend(model, blend: Blend, log_tokens: bool = False):
embeddings_to_blend = None
for i, flattened_prompt in enumerate(blend.prompts):
this_embedding, _ = _get_embeddings_and_tokens_for_prompt(model,
flattened_prompt,
log_tokens=log_tokens,
log_display_label=f"(blend part {i + 1}, weight={blend.weights[i]})")
embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat(
(embeddings_to_blend, this_embedding))
conditioning = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0),
blend.weights,
normalize=blend.normalize_weights)
return conditioning
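_get_conditioning_for_blend concatenates one embedding per blend part and hands the stack plus weights to apply_embedding_weights. A minimal sketch of the weighted-average idea, assuming plain normalized weighting; the real WeightedFrozenCLIPEmbedder implementation may normalize differently.

import torch

def blend_embeddings(parts: torch.Tensor, weights: list[float], normalize: bool = True) -> torch.Tensor:
    # parts: [N, seq_len, dim] -- one CLIP embedding per blend part.
    w = torch.tensor(weights, dtype=parts.dtype)
    if normalize:
        w = w / w.sum()
    # Weighted sum over the N blend parts, keeping the [seq_len, dim] shape.
    return (parts * w.view(-1, 1, 1)).sum(dim=0)

parts = torch.randn(2, 77, 768)                   # e.g. "mountain:1 lake:3"
print(blend_embeddings(parts, [1.0, 3.0]).shape)  # torch.Size([77, 768])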
def _get_embeddings_and_tokens_for_prompt(model, flattened_prompt: FlattenedPrompt, log_tokens: bool = False,
log_display_label: str = None):
if type(flattened_prompt) is not FlattenedPrompt:
raise Exception(f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead")
fragments = [x.text for x in flattened_prompt.children]
@ -181,12 +228,14 @@ def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: Fl
return embeddings, tokens
def get_tokens_length(model, fragments: list[Fragment]):
def _get_tokens_length(model, fragments: list[Fragment]):
fragment_texts = [x.text for x in fragments]
tokens = model.cond_stage_model.get_tokens(fragment_texts, include_start_and_end_markers=False)
return sum([len(x) for x in tokens])
def flatten_hybrid_conditioning(uncond, cond):
def _flatten_hybrid_conditioning(uncond, cond):
'''
This handles the choice between a conditional conditioning
that is a tensor (used by cross attention) vs one that has additional
@ -205,4 +254,29 @@ def flatten_hybrid_conditioning(uncond, cond):
cond_flattened[k] = torch.cat([uncond[k], cond[k]])
return uncond, cond_flattened
def log_tokenization(text, model, display_label=None):
""" shows how the prompt is tokenized
# usually tokens have '</w>' to indicate end-of-word,
# but for readability it has been replaced with ' '
"""
tokens = model.cond_stage_model.tokenizer.tokenize(text)
tokenized = ""
discarded = ""
usedTokens = 0
totalTokens = len(tokens)
for i in range(0, totalTokens):
token = tokens[i].replace('</w>', ' ')
# alternate color
s = (usedTokens % 6) + 1
if i < model.cond_stage_model.max_length:
tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
usedTokens += 1
else: # over max token length
discarded = discarded + f"\x1b[0;3{s};40m{token}"
print(f"\n>> Tokens {display_label or ''} ({usedTokens}):\n{tokenized}\x1b[0m")
if discarded != "":
print(
f">> Tokens Discarded ({totalTokens - usedTokens}):\n{discarded}\x1b[0m"
)

View File

@ -19,6 +19,7 @@ from pytorch_lightning import seed_everything
from tqdm import trange
from ldm.invoke.devices import choose_autocast
from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
from ldm.models.diffusion.ddpm import DiffusionWrapper
from ldm.util import rand_perlin_2d
@ -62,9 +63,12 @@ class Generator:
def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None,
image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
safety_checker:dict=None,
attention_maps_callback = None,
**kwargs):
scope = choose_autocast(self.precision)
self.safety_checker = safety_checker
attention_maps_images = []
attention_maps_callback = lambda saver: attention_maps_images.append(saver.get_stacked_maps_image())
make_image = self.get_make_image(
prompt,
sampler = sampler,
@ -74,6 +78,7 @@ class Generator:
step_callback = step_callback,
threshold = threshold,
perlin = perlin,
attention_maps_callback = attention_maps_callback,
**kwargs
)
results = []
@ -109,7 +114,7 @@ class Generator:
results.append([image, seed])
if image_callback is not None:
image_callback(image, seed, first_seed=first_seed)
image_callback(image, seed, first_seed=first_seed, attention_maps_image=attention_maps_images[-1])
seed = self.new_seed()

View File

@ -15,6 +15,7 @@ class Txt2Img(Generator):
@torch.no_grad()
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
conditioning,width,height,step_callback=None,threshold=0.0,perlin=0.0,
attention_maps_callback=None,
**kwargs):
"""
Returns a function returning an image derived from the prompt and the initial image
@ -39,6 +40,7 @@ class Txt2Img(Generator):
extra_conditioning_info=extra_conditioning_info,
# TODO: eta = ddim_eta,
# TODO: threshold = threshold,
attention_maps_callback = attention_maps_callback,
)
return pipeline.numpy_to_pil(pipeline_output.images)[0]

View File

@ -3,7 +3,7 @@ from typing import Union, Optional
import re
import pyparsing as pp
'''
This module parses prompt strings and produces tree-like structures that can be used to generate and control the conditioning tensors.
This module parses prompt strings and produces tree-like structures that can be used to generate and control the conditioning tensors.
weighted subprompts.
Useful class exports:
@ -69,6 +69,12 @@ class FlattenedPrompt():
return len(self.children) == 0 or \
(len(self.children) == 1 and len(self.children[0].text) == 0)
@property
def wants_cross_attention_control(self):
return any(
[issubclass(type(x), CrossAttentionControlledFragment) for x in self.children]
)
def __repr__(self):
return f"FlattenedPrompt:{self.children}"
def __eq__(self, other):
@ -240,6 +246,12 @@ class Blend():
self.weights = weights
self.normalize_weights = normalize_weights
@property
def wants_cross_attention_control(self):
# blends cannot use cross-attention control
return False
def __repr__(self):
return f"Blend:{self.prompts} | weights {' ' if self.normalize_weights else '(non-normalized) '}{self.weights}"
def __eq__(self, other):
@ -277,8 +289,8 @@ class PromptParser():
return self.flatten(root[0])
def parse_legacy_blend(self, text: str) -> Optional[Blend]:
weighted_subprompts = split_weighted_subprompts(text, skip_normalize=False)
def parse_legacy_blend(self, text: str, skip_normalize: bool) -> Optional[Blend]:
weighted_subprompts = split_weighted_subprompts(text, skip_normalize=skip_normalize)
if len(weighted_subprompts) <= 1:
return None
strings = [x[0] for x in weighted_subprompts]
@ -287,7 +299,7 @@ class PromptParser():
parsed_conjunctions = [self.parse_conjunction(x) for x in strings]
flattened_prompts = [x.prompts[0] for x in parsed_conjunctions]
return Blend(prompts=flattened_prompts, weights=weights, normalize_weights=True)
return Blend(prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize)
def flatten(self, root: Conjunction, verbose = False) -> Conjunction:
@ -641,27 +653,3 @@ def split_weighted_subprompts(text, skip_normalize=False)->list:
return [(x[0], equal_weight) for x in parsed_prompts]
return [(x[0], x[1] / weight_sum) for x in parsed_prompts]
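split_weighted_subprompts normalizes the legacy blend weights so they sum to 1 unless skip_normalize is set. A worked illustration of that final step; the parsed_prompts list below is hand-written rather than produced by the real regex.

# "a mountain:1 a lake:3" parsed into (text, weight) pairs:
parsed_prompts = [("a mountain", 1.0), ("a lake", 3.0)]
weight_sum = sum(weight for _, weight in parsed_prompts)
print([(text, weight / weight_sum) for text, weight in parsed_prompts])
# [('a mountain', 0.25), ('a lake', 0.75)]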
# shows how the prompt is tokenized
# usually tokens have '</w>' to indicate end-of-word,
# but for readability it has been replaced with ' '
def log_tokenization(text, model, display_label=None):
tokens = model.cond_stage_model.tokenizer.tokenize(text)
tokenized = ""
discarded = ""
usedTokens = 0
totalTokens = len(tokens)
for i in range(0, totalTokens):
token = tokens[i].replace('</w>', ' ')
# alternate color
s = (usedTokens % 6) + 1
if i < model.cond_stage_model.max_length:
tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
usedTokens += 1
else: # over max token length
discarded = discarded + f"\x1b[0;3{s};40m{token}"
print(f"\n>> Tokens {display_label or ''} ({usedTokens}):\n{tokenized}\x1b[0m")
if discarded != "":
print(
f">> Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m"
)

View File

@ -53,7 +53,6 @@ COMMANDS = (
'--codeformer_fidelity','-cf',
'--upscale','-U',
'-save_orig','--save_original',
'--skip_normalize','-x',
'--log_tokenization','-t',
'--hires_fix',
'--inpaint_replace','-r',
@ -117,19 +116,19 @@ class Completer(object):
# extensions defined, so go directly into path completion mode
if self.extensions is not None:
self.matches = self._path_completions(text, state, self.extensions)
# looking for an image file
elif re.search(path_regexp,buffer):
do_shortcut = re.search('^'+'|'.join(IMG_FILE_COMMANDS),buffer)
self.matches = self._path_completions(text, state, IMG_EXTENSIONS,shortcut_ok=do_shortcut)
# looking for a seed
elif re.search('(-S\s*|--seed[=\s])\d*$',buffer):
elif re.search('(-S\s*|--seed[=\s])\d*$',buffer):
self.matches= self._seed_completions(text,state)
elif re.search('<[\w-]*$',buffer):
elif re.search('<[\w-]*$',buffer):
self.matches= self._concept_completions(text,state)
# looking for a model
elif re.match('^'+'|'.join(MODEL_COMMANDS),buffer):
self.matches= self._model_completions(text, state)
@ -227,7 +226,7 @@ class Completer(object):
if h_len < 1:
print('<empty history>')
return
for i in range(0,h_len):
line = self.get_history_item(i+1)
if match and match not in line:
@ -367,7 +366,7 @@ class DummyCompleter(Completer):
def __init__(self,options):
super().__init__(options)
self.history = list()
def add_history(self,line):
self.history.append(line)

View File

@ -7,11 +7,9 @@ import torch
import diffusers
from torch import nn
# adapted from bloc97's CrossAttentionControl colab
# https://github.com/bloc97/CrossAttentionControl
class Arguments:
def __init__(self, edited_conditioning: torch.Tensor, edit_opcodes: list[tuple], edit_options: dict):
"""
@ -68,11 +66,11 @@ class Context:
self.clear_requests(cleanup=True)
def register_cross_attention_modules(self, model):
for name,module in get_attention_modules(model, CrossAttentionType.SELF):
for name,module in get_cross_attention_modules(model, CrossAttentionType.SELF):
if name in self.self_cross_attention_module_identifiers:
assert False, f"name {name} cannot appear more than once"
self.self_cross_attention_module_identifiers.append(name)
for name,module in get_attention_modules(model, CrossAttentionType.TOKENS):
for name,module in get_cross_attention_modules(model, CrossAttentionType.TOKENS):
if name in self.tokens_cross_attention_module_identifiers:
assert False, f"name {name} cannot appear more than once"
self.tokens_cross_attention_module_identifiers.append(name)
@ -175,6 +173,135 @@ class Context:
map_dict[offset] = slice.to('cpu')
class InvokeAICrossAttentionMixin:
"""
Enable InvokeAI-flavoured CrossAttention calculation, which does aggressive low-memory slicing and calls
through both to an attention_slice_wrangler and a slicing_strategy_getter for custom attention map wrangling
and dynamic slicing strategy selection.
"""
def __init__(self):
self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
self.attention_slice_wrangler = None
self.slicing_strategy_getter = None
self.attention_slice_calculated_callback = None
def set_attention_slice_wrangler(self, wrangler: Optional[Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]]):
'''
Set custom attention calculator to be called when attention is calculated
:param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size),
which returns either the suggested_attention_slice or an adjusted equivalent.
`module` is the current CrossAttention module for which the callback is being invoked.
`suggested_attention_slice` is the default-calculated attention slice
`dim` is -1 if the attention map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length.
Pass None to use the default attention calculation.
:return:
'''
self.attention_slice_wrangler = wrangler
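A wrangler with the documented signature might look like the following no-op passthrough; a real one could, for example, rescale the attention given to specific token positions. The name passthrough_wrangler is illustrative, not part of this diff.

import torch
from torch import nn

def passthrough_wrangler(module: nn.Module, suggested_attention_slice: torch.Tensor,
                         dim: int, offset: int, slice_size: int) -> torch.Tensor:
    # Return the default-calculated slice unchanged; a custom wrangler would
    # modify it (e.g. boost or suppress columns for particular tokens).
    return suggested_attention_slice

# usage, assuming `attn` is a module using InvokeAICrossAttentionMixin:
# attn.set_attention_slice_wrangler(passthrough_wrangler)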
def set_slicing_strategy_getter(self, getter: Optional[Callable[[nn.Module], tuple[int,int]]]):
self.slicing_strategy_getter = getter
def set_attention_slice_calculated_callback(self, callback: Optional[Callable[[torch.Tensor], None]]):
self.attention_slice_calculated_callback = callback
def einsum_lowest_level(self, query, key, value, dim, offset, slice_size):
# calculate attention scores
#attention_scores = torch.einsum('b i d, b j d -> b i j', q, k)
attention_scores = torch.baddbmm(
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
query,
key.transpose(-1, -2),
beta=0,
alpha=self.scale,
)
# calculate attention slice by taking the best scores for each latent pixel
default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
attention_slice_wrangler = self.attention_slice_wrangler
if attention_slice_wrangler is not None:
attention_slice = attention_slice_wrangler(self, default_attention_slice, dim, offset, slice_size)
else:
attention_slice = default_attention_slice
if self.attention_slice_calculated_callback is not None:
self.attention_slice_calculated_callback(attention_slice, dim, offset, slice_size)
hidden_states = torch.bmm(attention_slice, value)
return hidden_states
def einsum_op_slice_dim0(self, q, k, v, slice_size):
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
for i in range(0, q.shape[0], slice_size):
end = i + slice_size
r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size)
return r
def einsum_op_slice_dim1(self, q, k, v, slice_size):
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size)
return r
def einsum_op_mps_v1(self, q, k, v):
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
return self.einsum_lowest_level(q, k, v, None, None, None)
else:
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
return self.einsum_op_slice_dim1(q, k, v, slice_size)
def einsum_op_mps_v2(self, q, k, v):
if self.mem_total_gb > 8 and q.shape[1] <= 4096:
return self.einsum_lowest_level(q, k, v, None, None, None)
else:
return self.einsum_op_slice_dim0(q, k, v, 1)
def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb):
size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
if size_mb <= max_tensor_mb:
return self.einsum_lowest_level(q, k, v, None, None, None)
div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
if div <= q.shape[0]:
return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div)
return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1))
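The slicing decision in einsum_op_tensor_mem is driven by the estimated size of the q @ k^T score tensor. A worked example of the same arithmetic with hypothetical shapes for a 512x512 generation in fp16:

# Hypothetical shapes: batch*heads = 16, query/key length 4096 (a 64x64 latent), fp16.
q_rows, k_rows, batch_heads, elem_size = 4096, 4096, 16, 2
size_mb = batch_heads * q_rows * k_rows * elem_size // (1 << 20)    # 512 MB score tensor
max_tensor_mb = 32
div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()          # next power of two covering the ratio
print(size_mb, div)   # 512 16 -> dim-0 slicing with slice_size = batch_heads // div = 1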
def einsum_op_cuda(self, q, k, v):
# check if we already have a slicing strategy (this should only happen during cross-attention controlled generation)
slicing_strategy_getter = self.slicing_strategy_getter
if slicing_strategy_getter is not None:
(dim, slice_size) = slicing_strategy_getter(self)
if dim is not None:
# print("using saved slicing strategy with dim", dim, "slice size", slice_size)
if dim == 0:
return self.einsum_op_slice_dim0(q, k, v, slice_size)
elif dim == 1:
return self.einsum_op_slice_dim1(q, k, v, slice_size)
# fallback for when there is no saved strategy, or saved strategy does not slice
mem_free_total = self.cached_mem_free_total or get_mem_free_total(q.device)
# Divide factor of safety as there's copying and fragmentation
return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
def get_invokeai_attention_mem_efficient(self, q, k, v):
if q.device.type == 'cuda':
#print("in get_attention_mem_efficient with q shape", q.shape, ", k shape", k.shape, ", free memory is", get_mem_free_total(q.device))
return self.einsum_op_cuda(q, k, v)
if q.device.type == 'mps' or q.device.type == 'cpu':
if self.mem_total_gb >= 32:
return self.einsum_op_mps_v1(q, k, v)
return self.einsum_op_mps_v2(q, k, v)
# Smaller slices are faster due to L2/L3/SLC caches.
# Tested on i7 with 8MB L3 cache.
return self.einsum_op_tensor_mem(q, k, v, 32)
def remove_cross_attention_control(model):
remove_attention_function(model)
@ -210,8 +337,7 @@ def setup_cross_attention_control(model, context: Context):
inject_attention_function(model, context)
def get_attention_modules(model, which: CrossAttentionType):
# cross_attention_class: type = ldm.modules.attention.CrossAttention
def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
cross_attention_class: type = InvokeAIDiffusersCrossAttention
which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2"
attention_module_tuples = [(name,module) for name, module in model.named_modules() if
@ -220,7 +346,7 @@ def get_attention_modules(model, which: CrossAttentionType):
expected_count = 16
if cross_attention_modules_in_model_count != expected_count:
# non-fatal error but .swap() won't work.
print(f"Error! CrossAttentionControl found an unexpected number of InvokeAICrossAttention modules in the model " +
print(f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model " +
f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching failed " +
f"or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, " +
f"and/or update the {expected_count} above to an appropriate number, and/or find and inform someone who knows " +
@ -266,7 +392,7 @@ def inject_attention_function(unet, context: Context):
return attention_slice
cross_attention_modules = get_attention_modules(unet, CrossAttentionType.TOKENS) + get_attention_modules(unet, CrossAttentionType.SELF)
cross_attention_modules = get_cross_attention_modules(unet, CrossAttentionType.TOKENS) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
for identifier, module in cross_attention_modules:
module.identifier = identifier
try:
@ -282,7 +408,7 @@ def inject_attention_function(unet, context: Context):
def remove_attention_function(unet):
cross_attention_modules = get_attention_modules(unet, CrossAttentionType.TOKENS) + get_attention_modules(unet, CrossAttentionType.SELF)
cross_attention_modules = get_cross_attention_modules(unet, CrossAttentionType.TOKENS) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
for identifier, module in cross_attention_modules:
try:
# clear wrangler callback
@ -316,7 +442,6 @@ def get_mem_free_total(device):
return mem_free_total
class InvokeAICrossAttentionMixin:
"""
Enable InvokeAI-flavoured CrossAttention calculation, which does aggressive low-memory slicing and calls
@ -452,4 +577,3 @@ class InvokeAIDiffusersCrossAttention(diffusers.models.attention.CrossAttention,
hidden_states = self.reshape_batch_dim_to_heads(damian_result)
return hidden_states

View File

@ -0,0 +1,95 @@
import math
import PIL
import torch
from torchvision.transforms.functional import resize as tv_resize, InterpolationMode
from ldm.models.diffusion.cross_attention_control import get_cross_attention_modules, CrossAttentionType
class AttentionMapSaver():
def __init__(self, token_ids: range, latents_shape: torch.Size):
self.token_ids = token_ids
self.latents_shape = latents_shape
#self.collated_maps = #torch.zeros([len(token_ids), latents_shape[0], latents_shape[1]])
self.collated_maps = {}
def clear_maps(self):
self.collated_maps = {}
def add_attention_maps(self, maps: torch.Tensor, key: str):
"""
Accumulate the given attention maps and store by summing with existing maps at the passed-in key (if any).
:param maps: Attention maps to store. Expected shape [A, (H*W), N] where A is attention heads count, H and W are the map size (fixed per-key) and N is the number of tokens (typically 77).
:param key: Storage key. If a map already exists for this key it will be summed with the incoming data. In this case the maps sizes (H and W) should match.
:return: None
"""
key_and_size = f'{key}_{maps.shape[1]}'
# extract desired tokens
maps = maps[:, :, self.token_ids]
# merge attention heads to a single map per token
maps = torch.sum(maps, 0)
# store
if key_and_size not in self.collated_maps:
self.collated_maps[key_and_size] = torch.zeros_like(maps, device='cpu')
self.collated_maps[key_and_size] += maps.cpu()
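A shape walk-through of add_attention_maps with toy numbers; the sizes below are made up, whereas real maps come from the cross-attention layers during sampling.

import torch

heads, hw, n_tokens = 8, 64, 77             # e.g. an 8x8 map with the full 77 CLIP tokens
maps = torch.rand(heads, hw, n_tokens)
token_ids = range(1, 6)                     # prompt tokens only, excluding BOS/EOS

selected = maps[:, :, list(token_ids)]      # [8, 64, 5]  keep only the requested tokens
merged_heads = torch.sum(selected, 0)       # [64, 5]     one accumulated map per token
print(merged_heads.shape)                   # torch.Size([64, 5])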
def write_maps_to_disk(self, path: str):
pil_image = self.get_stacked_maps_image()
pil_image.save(path, 'PNG')
def get_stacked_maps_image(self) -> PIL.Image:
"""
Scale all collected attention maps to the same size, blend them together and return as an image.
:return: An image containing a vertical stack of blended attention maps, one for each requested token.
"""
num_tokens = len(self.token_ids)
if num_tokens == 0:
return None
latents_height = self.latents_shape[0]
latents_width = self.latents_shape[1]
merged = None
for key, maps in self.collated_maps.items():
# maps has shape [(H*W), N] for N tokens
# but we want [N, H, W]
this_scale_factor = math.sqrt(maps.shape[0] / (latents_width * latents_height))
this_maps_height = int(float(latents_height) * this_scale_factor)
this_maps_width = int(float(latents_width) * this_scale_factor)
# and we need to do some dimension juggling
maps = torch.reshape(torch.swapdims(maps, 0, 1), [num_tokens, this_maps_height, this_maps_width])
# scale to output size if necessary
if this_scale_factor != 1:
maps = tv_resize(maps, [latents_height, latents_width], InterpolationMode.BICUBIC)
# normalize
maps_min = torch.min(maps)
maps_range = torch.max(maps) - maps_min
#print(f"map {key} size {[this_maps_width, this_maps_height]} range {[maps_min, maps_min + maps_range]}")
maps_normalized = (maps - maps_min) / maps_range
# expand to (-0.1, 1.1) and clamp
maps_normalized_expanded = maps_normalized * 1.1 - 0.05
maps_normalized_expanded_clamped = torch.clamp(maps_normalized_expanded, 0, 1)
# merge together, producing a vertical stack
maps_stacked = torch.reshape(maps_normalized_expanded_clamped, [num_tokens * latents_height, latents_width])
if merged is None:
merged = maps_stacked
else:
# screen blend
merged = 1 - (1 - maps_stacked)*(1 - merged)
if merged is None:
return None
merged_bytes = merged.mul(0xff).byte()
return PIL.Image.fromarray(merged_bytes.numpy(), mode='L')
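Maps collected at different resolutions are combined with a screen blend, 1 - (1 - a) * (1 - b), so bright regions from any resolution survive in the merged image. A tiny numeric illustration:

import torch

a = torch.tensor([0.0, 0.5, 1.0])
b = torch.tensor([0.5, 0.5, 0.0])
print(1 - (1 - a) * (1 - b))    # tensor([0.5000, 0.7500, 1.0000])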

View File

@ -4,6 +4,7 @@ import k_diffusion as K
import torch
from torch import nn
from .cross_attention_map_saving import AttentionMapSaver
from .sampler import Sampler
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent
@ -36,6 +37,7 @@ class CFGDenoiser(nn.Module):
self.invokeai_diffuser = InvokeAIDiffuserComponent(model,
model_forward_callback=lambda x, sigma, cond: self.inner_model(x, sigma, cond=cond))
def prepare_to_sample(self, t_enc, **kwargs):
extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
@ -106,12 +108,12 @@ class KSampler(Sampler):
else:
print(f'>> Ksampler using karras noise schedule (steps < {self.karras_max})')
self.sigmas = self.karras_sigmas
# ALERT: We are completely overriding the sample() method in the base class, which
# means that inpainting will not work. To get this to work we need to be able to
# modify the inner loop of k_heun, k_lms, etc, as is done in an ugly way
# in the lstein/k-diffusion branch.
@torch.no_grad()
def decode(
self,
@ -145,7 +147,7 @@ class KSampler(Sampler):
@torch.no_grad()
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
return x0
# Most of these arguments are ignored and are only present for compatibility with
# other samplers
@torch.no_grad()
@ -158,6 +160,7 @@ class KSampler(Sampler):
callback=None,
normals_sequence=None,
img_callback=None,
attention_maps_callback=None,
quantize_x0=False,
eta=0.0,
mask=None,
@ -171,7 +174,7 @@ class KSampler(Sampler):
log_every_t=100,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
extra_conditioning_info=None,
extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo=None,
threshold = 0,
perlin = 0,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
@ -204,6 +207,12 @@ class KSampler(Sampler):
model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10))
model_wrap_cfg.prepare_to_sample(S, extra_conditioning_info=extra_conditioning_info)
attention_map_token_ids = range(1, extra_conditioning_info.tokens_count_including_eos_bos - 1)
attention_maps_saver = None if attention_maps_callback is None else AttentionMapSaver(token_ids = attention_map_token_ids, latents_shape=x.shape[-2:])
if attention_maps_callback is not None:
model_wrap_cfg.invokeai_diffuser.setup_attention_map_saving(attention_maps_saver)
extra_args = {
'cond': conditioning,
'uncond': unconditional_conditioning,
@ -217,6 +226,8 @@ class KSampler(Sampler):
),
None,
)
if attention_maps_callback is not None:
attention_maps_callback(attention_maps_saver)
return sampling_result
# this code will support inpainting if and when ksampler API modified or
@ -248,7 +259,7 @@ class KSampler(Sampler):
# terrible, confusing names here
steps = self.ddim_num_steps
t_enc = self.t_enc
# sigmas is `steps` entries long, but t_enc might
# be less. We start in the middle of the sigma array
# and work our way to the end after t_enc steps.
@ -280,7 +291,7 @@ class KSampler(Sampler):
return x_T + x
else:
return x
def prepare_to_sample(self,t_enc,**kwargs):
self.t_enc = t_enc
self.model_wrap = None

View File

@ -5,8 +5,8 @@ from typing import Callable, Optional, Union
import torch
from ldm.models.diffusion.cross_attention_control import Arguments, \
remove_cross_attention_control, setup_cross_attention_control, Context
from ldm.modules.attention import get_mem_free_total
remove_cross_attention_control, setup_cross_attention_control, Context, get_cross_attention_modules, CrossAttentionType
from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
class InvokeAIDiffuserComponent:
@ -21,7 +21,8 @@ class InvokeAIDiffuserComponent:
class ExtraConditioningInfo:
def __init__(self, cross_attention_control_args: Optional[Arguments]):
def __init__(self, tokens_count_including_eos_bos:int, cross_attention_control_args: Optional[Arguments]):
self.tokens_count_including_eos_bos = tokens_count_including_eos_bos
self.cross_attention_control_args = cross_attention_control_args
@property
@ -53,7 +54,25 @@ class InvokeAIDiffuserComponent:
self.cross_attention_control_context = None
remove_cross_attention_control(self.model)
def setup_attention_map_saving(self, saver: AttentionMapSaver):
def callback(slice, dim, offset, slice_size, key):
if dim is not None:
# sliced tokens attention map saving is not implemented
return
saver.add_attention_maps(slice, key)
tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
for identifier, module in tokens_cross_attention_modules:
key = ('down' if identifier.startswith('down') else
'up' if identifier.startswith('up') else
'mid')
module.set_attention_slice_calculated_callback(
lambda slice, dim, offset, slice_size, key=key: callback(slice, dim, offset, slice_size, key))
def remove_attention_map_saving(self):
tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
for _, module in tokens_cross_attention_modules:
module.set_attention_slice_calculated_callback(None)
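The key=key default argument in the lambda registered above binds the loop variable at definition time; without it, every module's callback would report the last identifier's key. A minimal illustration of that Python closure behaviour:

# Late binding: every lambda sees the final value of i.
late = [lambda: i for i in range(3)]
print([f() for f in late])       # [2, 2, 2]

# A default argument captures the value at each iteration instead.
bound = [lambda i=i: i for i in range(3)]
print([f() for f in bound])      # [0, 1, 2]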
def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor,
unconditioning: Union[torch.Tensor,dict],

View File

@ -10,8 +10,6 @@ from einops import rearrange, repeat
from ldm.models.diffusion.cross_attention_control import InvokeAICrossAttentionMixin
from ldm.modules.diffusionmodules.util import checkpoint
import psutil
def exists(val):
return val is not None
@ -165,10 +163,10 @@ def get_mem_free_total(device):
return mem_free_total
class CrossAttention(nn.Module, InvokeAICrossAttentionMixin):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
print(f"Warning! ldm.modules.attention.CrossAttention is no longer being maintained. Please use InvokeAICrossAttention instead.")
super().__init__()
InvokeAICrossAttentionMixin.__init__(self)
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
@ -184,7 +182,6 @@ class CrossAttention(nn.Module, InvokeAICrossAttentionMixin):
nn.Dropout(dropout)
)
def forward(self, x, context=None, mask=None):
h = self.heads
@ -196,7 +193,7 @@ class CrossAttention(nn.Module, InvokeAICrossAttentionMixin):
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
# prevent scale being applied twice
# don't apply scale twice
cached_scale = self.scale
self.scale = 1
r = self.get_invokeai_attention_mem_efficient(q, k, v)