mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge remote-tracking branch 'origin/main' into dev/diffusers
# Conflicts: # backend/invoke_ai_web_server.py # ldm/generate.py # ldm/invoke/CLI.py # ldm/invoke/generator/base.py # ldm/invoke/generator/txt2img.py # ldm/models/diffusion/cross_attention_control.py # ldm/modules/attention.py
This commit is contained in:
commit
63532226a5
@ -21,6 +21,7 @@ from backend.modules.get_canvas_generation_mode import (
|
||||
get_canvas_generation_mode,
|
||||
)
|
||||
from backend.modules.parameters import parameters_to_command
|
||||
from ldm.generate import Generate
|
||||
from ldm.invoke.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash
|
||||
from ldm.invoke.generator.diffusers_pipeline import PipelineIntermediateState
|
||||
from ldm.invoke.generator.inpaint import infill_methods
|
||||
@ -38,7 +39,7 @@ if not os.path.isabs(args.outdir):
|
||||
|
||||
|
||||
class InvokeAIWebServer:
|
||||
def __init__(self, generate, gfpgan, codeformer, esrgan) -> None:
|
||||
def __init__(self, generate: Generate, gfpgan, codeformer, esrgan) -> None:
|
||||
self.host = args.host
|
||||
self.port = args.port
|
||||
|
||||
@ -906,16 +907,13 @@ class InvokeAIWebServer:
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
if generation_parameters["progress_latents"]:
|
||||
image = self.generate.sample_to_lowres_estimated_image(sample)
|
||||
(width, height) = image.size
|
||||
width *= 8
|
||||
height *= 8
|
||||
buffered = io.BytesIO()
|
||||
image.save(buffered, format="PNG")
|
||||
img_base64 = "data:image/png;base64," + base64.b64encode(
|
||||
buffered.getvalue()
|
||||
).decode("UTF-8")
|
||||
img_base64 = image_to_dataURL(image)
|
||||
self.socketio.emit(
|
||||
"intermediateResult",
|
||||
{
|
||||
@ -933,7 +931,7 @@ class InvokeAIWebServer:
|
||||
self.socketio.emit("progressUpdate", progress.to_formatted_dict())
|
||||
eventlet.sleep(0)
|
||||
|
||||
def image_done(image, seed, first_seed):
|
||||
def image_done(image, seed, first_seed, attention_maps_image=None):
|
||||
if self.canceled.is_set():
|
||||
raise CanceledException
|
||||
|
||||
@ -1095,6 +1093,12 @@ class InvokeAIWebServer:
|
||||
self.socketio.emit("progressUpdate", progress.to_formatted_dict())
|
||||
eventlet.sleep(0)
|
||||
|
||||
parsed_prompt, _ = get_prompt_structure(generation_parameters["prompt"])
|
||||
tokens = None if type(parsed_prompt) is Blend else \
|
||||
get_tokens_for_prompt(self.generate.model, parsed_prompt)
|
||||
attention_maps_image_base64_url = None if attention_maps_image is None \
|
||||
else image_to_dataURL(attention_maps_image)
|
||||
|
||||
self.socketio.emit(
|
||||
"generationResult",
|
||||
{
|
||||
@ -1107,6 +1111,8 @@ class InvokeAIWebServer:
|
||||
"height": height,
|
||||
"boundingBox": original_bounding_box,
|
||||
"generationMode": generation_parameters["generation_mode"],
|
||||
"attentionMaps": attention_maps_image_base64_url,
|
||||
"tokens": tokens,
|
||||
},
|
||||
)
|
||||
eventlet.sleep(0)
|
||||
@ -1118,7 +1124,7 @@ class InvokeAIWebServer:
|
||||
self.generate.prompt2image(
|
||||
**generation_parameters,
|
||||
step_callback=image_progress,
|
||||
image_callback=image_done,
|
||||
image_callback=image_done
|
||||
)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
@ -1565,6 +1571,19 @@ def dataURL_to_image(dataURL: str) -> ImageType:
|
||||
)
|
||||
return image
|
||||
|
||||
"""
|
||||
Converts an image into a base64 image dataURL.
|
||||
"""
|
||||
|
||||
def image_to_dataURL(image: ImageType) -> str:
|
||||
buffered = io.BytesIO()
|
||||
image.save(buffered, format="PNG")
|
||||
image_base64 = "data:image/png;base64," + base64.b64encode(
|
||||
buffered.getvalue()
|
||||
).decode("UTF-8")
|
||||
return image_base64
|
||||
|
||||
|
||||
|
||||
"""
|
||||
Converts a base64 image dataURL into bytes.
|
||||
|
@ -16,7 +16,6 @@ import numpy as np
|
||||
import skimage
|
||||
import torch
|
||||
import transformers
|
||||
from PIL import Image, ImageOps
|
||||
from diffusers import HeunDiscreteScheduler
|
||||
from diffusers.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||
@ -27,6 +26,7 @@ from diffusers.schedulers.scheduling_ipndm import IPNDMScheduler
|
||||
from diffusers.schedulers.scheduling_lms_discrete import LMSDiscreteScheduler
|
||||
from diffusers.schedulers.scheduling_pndm import PNDMScheduler
|
||||
from omegaconf import OmegaConf
|
||||
from PIL import Image, ImageOps
|
||||
from pytorch_lightning import seed_everything, logging
|
||||
|
||||
from ldm.invoke.args import metadata_from_png
|
||||
@ -461,7 +461,7 @@ class Generate:
|
||||
try:
|
||||
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(
|
||||
prompt, model =self.model,
|
||||
skip_normalize=skip_normalize,
|
||||
skip_normalize_legacy_blend=skip_normalize,
|
||||
log_tokens =self.log_tokenization
|
||||
)
|
||||
|
||||
@ -613,8 +613,8 @@ class Generate:
|
||||
# todo: cross-attention control
|
||||
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(
|
||||
prompt, model =self.model,
|
||||
skip_normalize=opt.skip_normalize,
|
||||
log_tokens =opt.log_tokenization
|
||||
skip_normalize_legacy_blend=opt.skip_normalize,
|
||||
log_tokens =ldm.invoke.conditioning.log_tokenization
|
||||
)
|
||||
|
||||
if tool in ('gfpgan','codeformer','upscale'):
|
||||
|
@ -8,6 +8,7 @@ import time
|
||||
import traceback
|
||||
import yaml
|
||||
|
||||
from ldm.generate import Generate
|
||||
from ldm.invoke.globals import Globals
|
||||
from ldm.invoke.prompt_parser import PromptParser
|
||||
from ldm.invoke.readline import get_completer, Completer
|
||||
@ -282,7 +283,7 @@ def main_loop(gen, opt):
|
||||
prefix = file_writer.unique_prefix()
|
||||
step_callback = make_step_callback(gen, opt, prefix) if opt.save_intermediates > 0 else None
|
||||
|
||||
def image_writer(image, seed, upscaled=False, first_seed=None, use_prefix=None, prompt_in=None):
|
||||
def image_writer(image, seed, upscaled=False, first_seed=None, use_prefix=None, prompt_in=None, attention_maps_image=None):
|
||||
# note the seed is the seed of the current image
|
||||
# the first_seed is the original seed that noise is added to
|
||||
# when the -v switch is used to generate variations
|
||||
@ -790,7 +791,7 @@ def get_next_command(infile=None) -> str: # command string
|
||||
print(f'#{command}')
|
||||
return command
|
||||
|
||||
def invoke_ai_web_server_loop(gen, gfpgan, codeformer, esrgan):
|
||||
def invoke_ai_web_server_loop(gen: Generate, gfpgan, codeformer, esrgan):
|
||||
print('\n* --web was specified, starting web server...')
|
||||
from backend.invoke_ai_web_server import InvokeAIWebServer
|
||||
# Change working directory to the stable-diffusion directory
|
||||
|
@ -7,20 +7,46 @@ get_uc_and_c_and_ec() get the conditioned and unconditioned latent, an
|
||||
|
||||
'''
|
||||
import re
|
||||
from difflib import SequenceMatcher
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
|
||||
from .prompt_parser import PromptParser, Blend, FlattenedPrompt, \
|
||||
CrossAttentionControlledFragment, CrossAttentionControlSubstitute, Fragment, log_tokenization
|
||||
CrossAttentionControlledFragment, CrossAttentionControlSubstitute, Fragment
|
||||
from ..models.diffusion import cross_attention_control
|
||||
from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder
|
||||
|
||||
|
||||
def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_normalize=False):
|
||||
def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False):
|
||||
prompt, negative_prompt = get_prompt_structure(prompt_string,
|
||||
skip_normalize_legacy_blend=skip_normalize_legacy_blend)
|
||||
conditioning = _get_conditioning_for_prompt(prompt, negative_prompt, model, log_tokens)
|
||||
|
||||
return conditioning
|
||||
|
||||
|
||||
def get_prompt_structure(prompt_string, skip_normalize_legacy_blend: bool = False) -> (
|
||||
Union[FlattenedPrompt, Blend], FlattenedPrompt):
|
||||
"""
|
||||
parse the passed-in prompt string and return tuple (positive_prompt, negative_prompt)
|
||||
"""
|
||||
prompt, negative_prompt = _parse_prompt_string(prompt_string,
|
||||
skip_normalize_legacy_blend=skip_normalize_legacy_blend)
|
||||
return prompt, negative_prompt
|
||||
|
||||
|
||||
def get_tokens_for_prompt(model, parsed_prompt: FlattenedPrompt) -> [str]:
|
||||
text_fragments = [x.text if type(x) is Fragment else
|
||||
(" ".join([f.text for f in x.original]) if type(x) is CrossAttentionControlSubstitute else
|
||||
str(x))
|
||||
for x in parsed_prompt.children]
|
||||
text = " ".join(text_fragments)
|
||||
tokens = model.cond_stage_model.tokenizer.tokenize(text)
|
||||
return tokens
|
||||
|
||||
|
||||
def _parse_prompt_string(prompt_string_uncleaned, skip_normalize_legacy_blend=False) -> Union[FlattenedPrompt, Blend]:
|
||||
# Extract Unconditioned Words From Prompt
|
||||
unconditioned_words = ''
|
||||
unconditional_regex = r'\[(.*?)\]'
|
||||
@ -39,7 +65,7 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
|
||||
pp = PromptParser()
|
||||
|
||||
parsed_prompt: Union[FlattenedPrompt, Blend] = None
|
||||
legacy_blend: Blend = pp.parse_legacy_blend(prompt_string_cleaned)
|
||||
legacy_blend: Blend = pp.parse_legacy_blend(prompt_string_cleaned, skip_normalize_legacy_blend)
|
||||
if legacy_blend is not None:
|
||||
parsed_prompt = legacy_blend
|
||||
else:
|
||||
@ -47,129 +73,150 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
|
||||
parsed_prompt = pp.parse_conjunction(prompt_string_cleaned).prompts[0]
|
||||
|
||||
parsed_negative_prompt: FlattenedPrompt = pp.parse_conjunction(unconditioned_words).prompts[0]
|
||||
return parsed_prompt, parsed_negative_prompt
|
||||
|
||||
|
||||
def _get_conditioning_for_prompt(parsed_prompt: Union[Blend, FlattenedPrompt], parsed_negative_prompt: FlattenedPrompt,
|
||||
model, log_tokens=False) \
|
||||
-> tuple[torch.Tensor, torch.Tensor, InvokeAIDiffuserComponent.ExtraConditioningInfo]:
|
||||
"""
|
||||
Process prompt structure and tokens, and return (conditioning, unconditioning, extra_conditioning_info)
|
||||
"""
|
||||
|
||||
if log_tokens:
|
||||
print(f">> Parsed prompt to {parsed_prompt}")
|
||||
print(f">> Parsed negative prompt to {parsed_negative_prompt}")
|
||||
|
||||
conditioning = None
|
||||
cac_args:cross_attention_control.Arguments = None
|
||||
cac_args: cross_attention_control.Arguments = None
|
||||
|
||||
if type(parsed_prompt) is Blend:
|
||||
blend: Blend = parsed_prompt
|
||||
embeddings_to_blend = None
|
||||
for i,flattened_prompt in enumerate(blend.prompts):
|
||||
this_embedding, _ = build_embeddings_and_tokens_for_flattened_prompt(model,
|
||||
flattened_prompt,
|
||||
log_tokens=log_tokens,
|
||||
log_display_label=f"(blend part {i+1}, weight={blend.weights[i]})" )
|
||||
embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat(
|
||||
(embeddings_to_blend, this_embedding))
|
||||
conditioning = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0),
|
||||
blend.weights,
|
||||
normalize=blend.normalize_weights)
|
||||
else:
|
||||
flattened_prompt: FlattenedPrompt = parsed_prompt
|
||||
wants_cross_attention_control = type(flattened_prompt) is not Blend \
|
||||
and any([issubclass(type(x), CrossAttentionControlledFragment) for x in flattened_prompt.children])
|
||||
if wants_cross_attention_control:
|
||||
original_prompt = FlattenedPrompt()
|
||||
edited_prompt = FlattenedPrompt()
|
||||
# for name, a0, a1, b0, b1 in edit_opcodes: only name == 'equal' is currently parsed
|
||||
original_token_count = 0
|
||||
edited_token_count = 0
|
||||
edit_options = []
|
||||
edit_opcodes = []
|
||||
# beginning of sequence
|
||||
edit_opcodes.append(('equal', original_token_count, original_token_count+1, edited_token_count, edited_token_count+1))
|
||||
edit_options.append(None)
|
||||
original_token_count += 1
|
||||
edited_token_count += 1
|
||||
for fragment in flattened_prompt.children:
|
||||
if type(fragment) is CrossAttentionControlSubstitute:
|
||||
original_prompt.append(fragment.original)
|
||||
edited_prompt.append(fragment.edited)
|
||||
conditioning = _get_conditioning_for_blend(model, parsed_prompt, log_tokens)
|
||||
elif type(parsed_prompt) is FlattenedPrompt:
|
||||
if parsed_prompt.wants_cross_attention_control:
|
||||
conditioning, cac_args = _get_conditioning_for_cross_attention_control(model, parsed_prompt, log_tokens)
|
||||
|
||||
to_replace_token_count = get_tokens_length(model, fragment.original)
|
||||
replacement_token_count = get_tokens_length(model, fragment.edited)
|
||||
edit_opcodes.append(('replace',
|
||||
original_token_count, original_token_count + to_replace_token_count,
|
||||
edited_token_count, edited_token_count + replacement_token_count
|
||||
))
|
||||
original_token_count += to_replace_token_count
|
||||
edited_token_count += replacement_token_count
|
||||
edit_options.append(fragment.options)
|
||||
#elif type(fragment) is CrossAttentionControlAppend:
|
||||
# edited_prompt.append(fragment.fragment)
|
||||
else:
|
||||
# regular fragment
|
||||
original_prompt.append(fragment)
|
||||
edited_prompt.append(fragment)
|
||||
|
||||
count = get_tokens_length(model, [fragment])
|
||||
edit_opcodes.append(('equal', original_token_count, original_token_count+count, edited_token_count, edited_token_count+count))
|
||||
edit_options.append(None)
|
||||
original_token_count += count
|
||||
edited_token_count += count
|
||||
# end of sequence
|
||||
edit_opcodes.append(('equal', original_token_count, original_token_count+1, edited_token_count, edited_token_count+1))
|
||||
edit_options.append(None)
|
||||
original_token_count += 1
|
||||
edited_token_count += 1
|
||||
|
||||
original_embeddings, original_tokens = build_embeddings_and_tokens_for_flattened_prompt(model,
|
||||
original_prompt,
|
||||
log_tokens=log_tokens,
|
||||
log_display_label="(.swap originals)")
|
||||
# naïvely building a single edited_embeddings like this disregards the effects of changing the absolute location of
|
||||
# subsequent tokens when there is >1 edit and earlier edits change the total token count.
|
||||
# eg "a cat.swap(smiling dog, s_start=0.5) eating a hotdog.swap(pizza)" - when the 'pizza' edit is active but the
|
||||
# 'cat' edit is not, the 'pizza' feature vector will nevertheless be affected by the introduction of the extra
|
||||
# token 'smiling' in the inactive 'cat' edit.
|
||||
# todo: build multiple edited_embeddings, one for each edit, and pass just the edited fragments through to the CrossAttentionControl functions
|
||||
edited_embeddings, edited_tokens = build_embeddings_and_tokens_for_flattened_prompt(model,
|
||||
edited_prompt,
|
||||
log_tokens=log_tokens,
|
||||
log_display_label="(.swap replacements)")
|
||||
|
||||
conditioning = original_embeddings
|
||||
edited_conditioning = edited_embeddings
|
||||
#print('>> got edit_opcodes', edit_opcodes, 'options', edit_options)
|
||||
cac_args = cross_attention_control.Arguments(
|
||||
edited_conditioning = edited_conditioning,
|
||||
edit_opcodes = edit_opcodes,
|
||||
edit_options = edit_options
|
||||
)
|
||||
else:
|
||||
conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model,
|
||||
flattened_prompt,
|
||||
log_tokens=log_tokens,
|
||||
log_display_label="(prompt)")
|
||||
conditioning, _ = _get_embeddings_and_tokens_for_prompt(model,
|
||||
parsed_prompt,
|
||||
log_tokens=log_tokens,
|
||||
log_display_label="(prompt)")
|
||||
else:
|
||||
raise ValueError(f"parsed_prompt is '{type(parsed_prompt)}' which is not a supported prompt type")
|
||||
|
||||
unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model,
|
||||
parsed_negative_prompt,
|
||||
log_tokens=log_tokens,
|
||||
log_display_label="(unconditioning)")
|
||||
unconditioning, _ = _get_embeddings_and_tokens_for_prompt(model,
|
||||
parsed_negative_prompt,
|
||||
log_tokens=log_tokens,
|
||||
log_display_label="(unconditioning)")
|
||||
if isinstance(conditioning, dict):
|
||||
# hybrid conditioning is in play
|
||||
unconditioning, conditioning = flatten_hybrid_conditioning(unconditioning, conditioning)
|
||||
unconditioning, conditioning = _flatten_hybrid_conditioning(unconditioning, conditioning)
|
||||
if cac_args is not None:
|
||||
print(">> Hybrid conditioning cannot currently be combined with cross attention control. Cross attention control will be ignored.")
|
||||
print(
|
||||
">> Hybrid conditioning cannot currently be combined with cross attention control. Cross attention control will be ignored.")
|
||||
cac_args = None
|
||||
|
||||
eos_token_index = 1
|
||||
if type(parsed_prompt) is not Blend:
|
||||
tokens = get_tokens_for_prompt(model, parsed_prompt)
|
||||
eos_token_index = len(tokens)+1
|
||||
return (
|
||||
unconditioning, conditioning, InvokeAIDiffuserComponent.ExtraConditioningInfo(
|
||||
tokens_count_including_eos_bos=eos_token_index + 1,
|
||||
cross_attention_control_args=cac_args
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def build_token_edit_opcodes(original_tokens, edited_tokens):
|
||||
original_tokens = original_tokens.cpu().numpy()[0]
|
||||
edited_tokens = edited_tokens.cpu().numpy()[0]
|
||||
def _get_conditioning_for_cross_attention_control(model, prompt: FlattenedPrompt, log_tokens: bool = True):
|
||||
original_prompt = FlattenedPrompt()
|
||||
edited_prompt = FlattenedPrompt()
|
||||
# for name, a0, a1, b0, b1 in edit_opcodes: only name == 'equal' is currently parsed
|
||||
original_token_count = 0
|
||||
edited_token_count = 0
|
||||
edit_options = []
|
||||
edit_opcodes = []
|
||||
# beginning of sequence
|
||||
edit_opcodes.append(
|
||||
('equal', original_token_count, original_token_count + 1, edited_token_count, edited_token_count + 1))
|
||||
edit_options.append(None)
|
||||
original_token_count += 1
|
||||
edited_token_count += 1
|
||||
for fragment in prompt.children:
|
||||
if type(fragment) is CrossAttentionControlSubstitute:
|
||||
original_prompt.append(fragment.original)
|
||||
edited_prompt.append(fragment.edited)
|
||||
|
||||
return SequenceMatcher(None, original_tokens, edited_tokens).get_opcodes()
|
||||
to_replace_token_count = _get_tokens_length(model, fragment.original)
|
||||
replacement_token_count = _get_tokens_length(model, fragment.edited)
|
||||
edit_opcodes.append(('replace',
|
||||
original_token_count, original_token_count + to_replace_token_count,
|
||||
edited_token_count, edited_token_count + replacement_token_count
|
||||
))
|
||||
original_token_count += to_replace_token_count
|
||||
edited_token_count += replacement_token_count
|
||||
edit_options.append(fragment.options)
|
||||
# elif type(fragment) is CrossAttentionControlAppend:
|
||||
# edited_prompt.append(fragment.fragment)
|
||||
else:
|
||||
# regular fragment
|
||||
original_prompt.append(fragment)
|
||||
edited_prompt.append(fragment)
|
||||
|
||||
def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt, log_tokens: bool=False, log_display_label: str=None):
|
||||
count = _get_tokens_length(model, [fragment])
|
||||
edit_opcodes.append(('equal', original_token_count, original_token_count + count, edited_token_count,
|
||||
edited_token_count + count))
|
||||
edit_options.append(None)
|
||||
original_token_count += count
|
||||
edited_token_count += count
|
||||
# end of sequence
|
||||
edit_opcodes.append(
|
||||
('equal', original_token_count, original_token_count + 1, edited_token_count, edited_token_count + 1))
|
||||
edit_options.append(None)
|
||||
original_token_count += 1
|
||||
edited_token_count += 1
|
||||
original_embeddings, original_tokens = _get_embeddings_and_tokens_for_prompt(model,
|
||||
original_prompt,
|
||||
log_tokens=log_tokens,
|
||||
log_display_label="(.swap originals)")
|
||||
# naïvely building a single edited_embeddings like this disregards the effects of changing the absolute location of
|
||||
# subsequent tokens when there is >1 edit and earlier edits change the total token count.
|
||||
# eg "a cat.swap(smiling dog, s_start=0.5) eating a hotdog.swap(pizza)" - when the 'pizza' edit is active but the
|
||||
# 'cat' edit is not, the 'pizza' feature vector will nevertheless be affected by the introduction of the extra
|
||||
# token 'smiling' in the inactive 'cat' edit.
|
||||
# todo: build multiple edited_embeddings, one for each edit, and pass just the edited fragments through to the CrossAttentionControl functions
|
||||
edited_embeddings, edited_tokens = _get_embeddings_and_tokens_for_prompt(model,
|
||||
edited_prompt,
|
||||
log_tokens=log_tokens,
|
||||
log_display_label="(.swap replacements)")
|
||||
conditioning = original_embeddings
|
||||
edited_conditioning = edited_embeddings
|
||||
# print('>> got edit_opcodes', edit_opcodes, 'options', edit_options)
|
||||
cac_args = cross_attention_control.Arguments(
|
||||
edited_conditioning=edited_conditioning,
|
||||
edit_opcodes=edit_opcodes,
|
||||
edit_options=edit_options
|
||||
)
|
||||
return conditioning, cac_args
|
||||
|
||||
|
||||
def _get_conditioning_for_blend(model, blend: Blend, log_tokens: bool = False):
|
||||
embeddings_to_blend = None
|
||||
for i, flattened_prompt in enumerate(blend.prompts):
|
||||
this_embedding, _ = _get_embeddings_and_tokens_for_prompt(model,
|
||||
flattened_prompt,
|
||||
log_tokens=log_tokens,
|
||||
log_display_label=f"(blend part {i + 1}, weight={blend.weights[i]})")
|
||||
embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat(
|
||||
(embeddings_to_blend, this_embedding))
|
||||
conditioning = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0),
|
||||
blend.weights,
|
||||
normalize=blend.normalize_weights)
|
||||
return conditioning
|
||||
|
||||
|
||||
def _get_embeddings_and_tokens_for_prompt(model, flattened_prompt: FlattenedPrompt, log_tokens: bool = False,
|
||||
log_display_label: str = None):
|
||||
if type(flattened_prompt) is not FlattenedPrompt:
|
||||
raise Exception(f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead")
|
||||
fragments = [x.text for x in flattened_prompt.children]
|
||||
@ -181,12 +228,14 @@ def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: Fl
|
||||
|
||||
return embeddings, tokens
|
||||
|
||||
def get_tokens_length(model, fragments: list[Fragment]):
|
||||
|
||||
def _get_tokens_length(model, fragments: list[Fragment]):
|
||||
fragment_texts = [x.text for x in fragments]
|
||||
tokens = model.cond_stage_model.get_tokens(fragment_texts, include_start_and_end_markers=False)
|
||||
return sum([len(x) for x in tokens])
|
||||
|
||||
def flatten_hybrid_conditioning(uncond, cond):
|
||||
|
||||
def _flatten_hybrid_conditioning(uncond, cond):
|
||||
'''
|
||||
This handles the choice between a conditional conditioning
|
||||
that is a tensor (used by cross attention) vs one that has additional
|
||||
@ -205,4 +254,29 @@ def flatten_hybrid_conditioning(uncond, cond):
|
||||
cond_flattened[k] = torch.cat([uncond[k], cond[k]])
|
||||
return uncond, cond_flattened
|
||||
|
||||
|
||||
|
||||
def log_tokenization(text, model, display_label=None):
|
||||
""" shows how the prompt is tokenized
|
||||
# usually tokens have '</w>' to indicate end-of-word,
|
||||
# but for readability it has been replaced with ' '
|
||||
"""
|
||||
|
||||
tokens = model.cond_stage_model.tokenizer.tokenize(text)
|
||||
tokenized = ""
|
||||
discarded = ""
|
||||
usedTokens = 0
|
||||
totalTokens = len(tokens)
|
||||
for i in range(0, totalTokens):
|
||||
token = tokens[i].replace('</w>', ' ')
|
||||
# alternate color
|
||||
s = (usedTokens % 6) + 1
|
||||
if i < model.cond_stage_model.max_length:
|
||||
tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
|
||||
usedTokens += 1
|
||||
else: # over max token length
|
||||
discarded = discarded + f"\x1b[0;3{s};40m{token}"
|
||||
print(f"\n>> Tokens {display_label or ''} ({usedTokens}):\n{tokenized}\x1b[0m")
|
||||
if discarded != "":
|
||||
print(
|
||||
f">> Tokens Discarded ({totalTokens - usedTokens}):\n{discarded}\x1b[0m"
|
||||
)
|
||||
|
@ -19,6 +19,7 @@ from pytorch_lightning import seed_everything
|
||||
from tqdm import trange
|
||||
|
||||
from ldm.invoke.devices import choose_autocast
|
||||
from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
|
||||
from ldm.models.diffusion.ddpm import DiffusionWrapper
|
||||
from ldm.util import rand_perlin_2d
|
||||
|
||||
@ -62,9 +63,12 @@ class Generator:
|
||||
def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None,
|
||||
image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
|
||||
safety_checker:dict=None,
|
||||
attention_maps_callback = None,
|
||||
**kwargs):
|
||||
scope = choose_autocast(self.precision)
|
||||
self.safety_checker = safety_checker
|
||||
attention_maps_images = []
|
||||
attention_maps_callback = lambda saver: attention_maps_images.append(saver.get_stacked_maps_image())
|
||||
make_image = self.get_make_image(
|
||||
prompt,
|
||||
sampler = sampler,
|
||||
@ -74,6 +78,7 @@ class Generator:
|
||||
step_callback = step_callback,
|
||||
threshold = threshold,
|
||||
perlin = perlin,
|
||||
attention_maps_callback = attention_maps_callback,
|
||||
**kwargs
|
||||
)
|
||||
results = []
|
||||
@ -109,7 +114,7 @@ class Generator:
|
||||
results.append([image, seed])
|
||||
|
||||
if image_callback is not None:
|
||||
image_callback(image, seed, first_seed=first_seed)
|
||||
image_callback(image, seed, first_seed=first_seed, attention_maps_image=attention_maps_images[-1])
|
||||
|
||||
seed = self.new_seed()
|
||||
|
||||
|
@ -15,6 +15,7 @@ class Txt2Img(Generator):
|
||||
@torch.no_grad()
|
||||
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
|
||||
conditioning,width,height,step_callback=None,threshold=0.0,perlin=0.0,
|
||||
attention_maps_callback=None,
|
||||
**kwargs):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and the initial image
|
||||
@ -39,6 +40,7 @@ class Txt2Img(Generator):
|
||||
extra_conditioning_info=extra_conditioning_info,
|
||||
# TODO: eta = ddim_eta,
|
||||
# TODO: threshold = threshold,
|
||||
attention_maps_callback = attention_maps_callback,
|
||||
)
|
||||
|
||||
return pipeline.numpy_to_pil(pipeline_output.images)[0]
|
||||
|
@ -3,7 +3,7 @@ from typing import Union, Optional
|
||||
import re
|
||||
import pyparsing as pp
|
||||
'''
|
||||
This module parses prompt strings and produces tree-like structures that can be used generate and control the conditioning tensors.
|
||||
This module parses prompt strings and produces tree-like structures that can be used generate and control the conditioning tensors.
|
||||
weighted subprompts.
|
||||
|
||||
Useful class exports:
|
||||
@ -69,6 +69,12 @@ class FlattenedPrompt():
|
||||
return len(self.children) == 0 or \
|
||||
(len(self.children) == 1 and len(self.children[0].text) == 0)
|
||||
|
||||
@property
|
||||
def wants_cross_attention_control(self):
|
||||
return any(
|
||||
[issubclass(type(x), CrossAttentionControlledFragment) for x in self.children]
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"FlattenedPrompt:{self.children}"
|
||||
def __eq__(self, other):
|
||||
@ -240,6 +246,12 @@ class Blend():
|
||||
self.weights = weights
|
||||
self.normalize_weights = normalize_weights
|
||||
|
||||
@property
|
||||
def wants_cross_attention_control(self):
|
||||
# blends cannot cross-attention control
|
||||
return False
|
||||
|
||||
|
||||
def __repr__(self):
|
||||
return f"Blend:{self.prompts} | weights {' ' if self.normalize_weights else '(non-normalized) '}{self.weights}"
|
||||
def __eq__(self, other):
|
||||
@ -277,8 +289,8 @@ class PromptParser():
|
||||
|
||||
return self.flatten(root[0])
|
||||
|
||||
def parse_legacy_blend(self, text: str) -> Optional[Blend]:
|
||||
weighted_subprompts = split_weighted_subprompts(text, skip_normalize=False)
|
||||
def parse_legacy_blend(self, text: str, skip_normalize: bool) -> Optional[Blend]:
|
||||
weighted_subprompts = split_weighted_subprompts(text, skip_normalize=skip_normalize)
|
||||
if len(weighted_subprompts) <= 1:
|
||||
return None
|
||||
strings = [x[0] for x in weighted_subprompts]
|
||||
@ -287,7 +299,7 @@ class PromptParser():
|
||||
parsed_conjunctions = [self.parse_conjunction(x) for x in strings]
|
||||
flattened_prompts = [x.prompts[0] for x in parsed_conjunctions]
|
||||
|
||||
return Blend(prompts=flattened_prompts, weights=weights, normalize_weights=True)
|
||||
return Blend(prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize)
|
||||
|
||||
|
||||
def flatten(self, root: Conjunction, verbose = False) -> Conjunction:
|
||||
@ -641,27 +653,3 @@ def split_weighted_subprompts(text, skip_normalize=False)->list:
|
||||
return [(x[0], equal_weight) for x in parsed_prompts]
|
||||
return [(x[0], x[1] / weight_sum) for x in parsed_prompts]
|
||||
|
||||
|
||||
# shows how the prompt is tokenized
|
||||
# usually tokens have '</w>' to indicate end-of-word,
|
||||
# but for readability it has been replaced with ' '
|
||||
def log_tokenization(text, model, display_label=None):
|
||||
tokens = model.cond_stage_model.tokenizer.tokenize(text)
|
||||
tokenized = ""
|
||||
discarded = ""
|
||||
usedTokens = 0
|
||||
totalTokens = len(tokens)
|
||||
for i in range(0, totalTokens):
|
||||
token = tokens[i].replace('</w>', ' ')
|
||||
# alternate color
|
||||
s = (usedTokens % 6) + 1
|
||||
if i < model.cond_stage_model.max_length:
|
||||
tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
|
||||
usedTokens += 1
|
||||
else: # over max token length
|
||||
discarded = discarded + f"\x1b[0;3{s};40m{token}"
|
||||
print(f"\n>> Tokens {display_label or ''} ({usedTokens}):\n{tokenized}\x1b[0m")
|
||||
if discarded != "":
|
||||
print(
|
||||
f">> Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m"
|
||||
)
|
||||
|
@ -53,7 +53,6 @@ COMMANDS = (
|
||||
'--codeformer_fidelity','-cf',
|
||||
'--upscale','-U',
|
||||
'-save_orig','--save_original',
|
||||
'--skip_normalize','-x',
|
||||
'--log_tokenization','-t',
|
||||
'--hires_fix',
|
||||
'--inpaint_replace','-r',
|
||||
@ -117,19 +116,19 @@ class Completer(object):
|
||||
# extensions defined, so go directly into path completion mode
|
||||
if self.extensions is not None:
|
||||
self.matches = self._path_completions(text, state, self.extensions)
|
||||
|
||||
|
||||
# looking for an image file
|
||||
elif re.search(path_regexp,buffer):
|
||||
do_shortcut = re.search('^'+'|'.join(IMG_FILE_COMMANDS),buffer)
|
||||
self.matches = self._path_completions(text, state, IMG_EXTENSIONS,shortcut_ok=do_shortcut)
|
||||
|
||||
# looking for a seed
|
||||
elif re.search('(-S\s*|--seed[=\s])\d*$',buffer):
|
||||
elif re.search('(-S\s*|--seed[=\s])\d*$',buffer):
|
||||
self.matches= self._seed_completions(text,state)
|
||||
|
||||
elif re.search('<[\w-]*$',buffer):
|
||||
elif re.search('<[\w-]*$',buffer):
|
||||
self.matches= self._concept_completions(text,state)
|
||||
|
||||
|
||||
# looking for a model
|
||||
elif re.match('^'+'|'.join(MODEL_COMMANDS),buffer):
|
||||
self.matches= self._model_completions(text, state)
|
||||
@ -227,7 +226,7 @@ class Completer(object):
|
||||
if h_len < 1:
|
||||
print('<empty history>')
|
||||
return
|
||||
|
||||
|
||||
for i in range(0,h_len):
|
||||
line = self.get_history_item(i+1)
|
||||
if match and match not in line:
|
||||
@ -367,7 +366,7 @@ class DummyCompleter(Completer):
|
||||
def __init__(self,options):
|
||||
super().__init__(options)
|
||||
self.history = list()
|
||||
|
||||
|
||||
def add_history(self,line):
|
||||
self.history.append(line)
|
||||
|
||||
|
@ -7,11 +7,9 @@ import torch
|
||||
import diffusers
|
||||
from torch import nn
|
||||
|
||||
|
||||
# adapted from bloc97's CrossAttentionControl colab
|
||||
# https://github.com/bloc97/CrossAttentionControl
|
||||
|
||||
|
||||
class Arguments:
|
||||
def __init__(self, edited_conditioning: torch.Tensor, edit_opcodes: list[tuple], edit_options: dict):
|
||||
"""
|
||||
@ -68,11 +66,11 @@ class Context:
|
||||
self.clear_requests(cleanup=True)
|
||||
|
||||
def register_cross_attention_modules(self, model):
|
||||
for name,module in get_attention_modules(model, CrossAttentionType.SELF):
|
||||
for name,module in get_cross_attention_modules(model, CrossAttentionType.SELF):
|
||||
if name in self.self_cross_attention_module_identifiers:
|
||||
assert False, f"name {name} cannot appear more than once"
|
||||
self.self_cross_attention_module_identifiers.append(name)
|
||||
for name,module in get_attention_modules(model, CrossAttentionType.TOKENS):
|
||||
for name,module in get_cross_attention_modules(model, CrossAttentionType.TOKENS):
|
||||
if name in self.tokens_cross_attention_module_identifiers:
|
||||
assert False, f"name {name} cannot appear more than once"
|
||||
self.tokens_cross_attention_module_identifiers.append(name)
|
||||
@ -175,6 +173,135 @@ class Context:
|
||||
map_dict[offset] = slice.to('cpu')
|
||||
|
||||
|
||||
|
||||
class InvokeAICrossAttentionMixin:
|
||||
"""
|
||||
Enable InvokeAI-flavoured CrossAttention calculation, which does aggressive low-memory slicing and calls
|
||||
through both to an attention_slice_wrangler and a slicing_strategy_getter for custom attention map wrangling
|
||||
and dymamic slicing strategy selection.
|
||||
"""
|
||||
def __init__(self):
|
||||
self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
|
||||
self.attention_slice_wrangler = None
|
||||
self.slicing_strategy_getter = None
|
||||
self.attention_slice_calculated_callback = None
|
||||
|
||||
def set_attention_slice_wrangler(self, wrangler: Optional[Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]]):
|
||||
'''
|
||||
Set custom attention calculator to be called when attention is calculated
|
||||
:param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size),
|
||||
which returns either the suggested_attention_slice or an adjusted equivalent.
|
||||
`module` is the current CrossAttention module for which the callback is being invoked.
|
||||
`suggested_attention_slice` is the default-calculated attention slice
|
||||
`dim` is -1 if the attenion map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
|
||||
If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length.
|
||||
|
||||
Pass None to use the default attention calculation.
|
||||
:return:
|
||||
'''
|
||||
self.attention_slice_wrangler = wrangler
|
||||
|
||||
def set_slicing_strategy_getter(self, getter: Optional[Callable[[nn.Module], tuple[int,int]]]):
|
||||
self.slicing_strategy_getter = getter
|
||||
|
||||
def set_attention_slice_calculated_callback(self, callback: Optional[Callable[[torch.Tensor], None]]):
|
||||
self.attention_slice_calculated_callback = callback
|
||||
|
||||
def einsum_lowest_level(self, query, key, value, dim, offset, slice_size):
|
||||
# calculate attention scores
|
||||
#attention_scores = torch.einsum('b i d, b j d -> b i j', q, k)
|
||||
attention_scores = torch.baddbmm(
|
||||
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
|
||||
query,
|
||||
key.transpose(-1, -2),
|
||||
beta=0,
|
||||
alpha=self.scale,
|
||||
)
|
||||
|
||||
# calculate attention slice by taking the best scores for each latent pixel
|
||||
default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
|
||||
attention_slice_wrangler = self.attention_slice_wrangler
|
||||
if attention_slice_wrangler is not None:
|
||||
attention_slice = attention_slice_wrangler(self, default_attention_slice, dim, offset, slice_size)
|
||||
else:
|
||||
attention_slice = default_attention_slice
|
||||
|
||||
if self.attention_slice_calculated_callback is not None:
|
||||
self.attention_slice_calculated_callback(attention_slice, dim, offset, slice_size)
|
||||
|
||||
hidden_states = torch.bmm(attention_slice, value)
|
||||
return hidden_states
|
||||
|
||||
def einsum_op_slice_dim0(self, q, k, v, slice_size):
|
||||
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
for i in range(0, q.shape[0], slice_size):
|
||||
end = i + slice_size
|
||||
r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size)
|
||||
return r
|
||||
|
||||
def einsum_op_slice_dim1(self, q, k, v, slice_size):
|
||||
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size)
|
||||
return r
|
||||
|
||||
def einsum_op_mps_v1(self, q, k, v):
|
||||
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
|
||||
return self.einsum_lowest_level(q, k, v, None, None, None)
|
||||
else:
|
||||
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
|
||||
return self.einsum_op_slice_dim1(q, k, v, slice_size)
|
||||
|
||||
def einsum_op_mps_v2(self, q, k, v):
|
||||
if self.mem_total_gb > 8 and q.shape[1] <= 4096:
|
||||
return self.einsum_lowest_level(q, k, v, None, None, None)
|
||||
else:
|
||||
return self.einsum_op_slice_dim0(q, k, v, 1)
|
||||
|
||||
def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb):
|
||||
size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
|
||||
if size_mb <= max_tensor_mb:
|
||||
return self.einsum_lowest_level(q, k, v, None, None, None)
|
||||
div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
|
||||
if div <= q.shape[0]:
|
||||
return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div)
|
||||
return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1))
|
||||
|
||||
def einsum_op_cuda(self, q, k, v):
|
||||
# check if we already have a slicing strategy (this should only happen during cross-attention controlled generation)
|
||||
slicing_strategy_getter = self.slicing_strategy_getter
|
||||
if slicing_strategy_getter is not None:
|
||||
(dim, slice_size) = slicing_strategy_getter(self)
|
||||
if dim is not None:
|
||||
# print("using saved slicing strategy with dim", dim, "slice size", slice_size)
|
||||
if dim == 0:
|
||||
return self.einsum_op_slice_dim0(q, k, v, slice_size)
|
||||
elif dim == 1:
|
||||
return self.einsum_op_slice_dim1(q, k, v, slice_size)
|
||||
|
||||
# fallback for when there is no saved strategy, or saved strategy does not slice
|
||||
mem_free_total = self.cached_mem_free_total or get_mem_free_total(q.device)
|
||||
# Divide factor of safety as there's copying and fragmentation
|
||||
return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
|
||||
|
||||
|
||||
def get_invokeai_attention_mem_efficient(self, q, k, v):
|
||||
if q.device.type == 'cuda':
|
||||
#print("in get_attention_mem_efficient with q shape", q.shape, ", k shape", k.shape, ", free memory is", get_mem_free_total(q.device))
|
||||
return self.einsum_op_cuda(q, k, v)
|
||||
|
||||
if q.device.type == 'mps' or q.device.type == 'cpu':
|
||||
if self.mem_total_gb >= 32:
|
||||
return self.einsum_op_mps_v1(q, k, v)
|
||||
return self.einsum_op_mps_v2(q, k, v)
|
||||
|
||||
# Smaller slices are faster due to L2/L3/SLC caches.
|
||||
# Tested on i7 with 8MB L3 cache.
|
||||
return self.einsum_op_tensor_mem(q, k, v, 32)
|
||||
|
||||
|
||||
|
||||
def remove_cross_attention_control(model):
|
||||
remove_attention_function(model)
|
||||
|
||||
@ -210,8 +337,7 @@ def setup_cross_attention_control(model, context: Context):
|
||||
inject_attention_function(model, context)
|
||||
|
||||
|
||||
def get_attention_modules(model, which: CrossAttentionType):
|
||||
# cross_attention_class: type = ldm.modules.attention.CrossAttention
|
||||
def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
|
||||
cross_attention_class: type = InvokeAIDiffusersCrossAttention
|
||||
which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2"
|
||||
attention_module_tuples = [(name,module) for name, module in model.named_modules() if
|
||||
@ -220,7 +346,7 @@ def get_attention_modules(model, which: CrossAttentionType):
|
||||
expected_count = 16
|
||||
if cross_attention_modules_in_model_count != expected_count:
|
||||
# non-fatal error but .swap() won't work.
|
||||
print(f"Error! CrossAttentionControl found an unexpected number of InvokeAICrossAttention modules in the model " +
|
||||
print(f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model " +
|
||||
f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching failed " +
|
||||
f"or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, " +
|
||||
f"and/or update the {expected_count} above to an appropriate number, and/or find and inform someone who knows " +
|
||||
@ -266,7 +392,7 @@ def inject_attention_function(unet, context: Context):
|
||||
|
||||
return attention_slice
|
||||
|
||||
cross_attention_modules = get_attention_modules(unet, CrossAttentionType.TOKENS) + get_attention_modules(unet, CrossAttentionType.SELF)
|
||||
cross_attention_modules = get_cross_attention_modules(unet, CrossAttentionType.TOKENS) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
|
||||
for identifier, module in cross_attention_modules:
|
||||
module.identifier = identifier
|
||||
try:
|
||||
@ -282,7 +408,7 @@ def inject_attention_function(unet, context: Context):
|
||||
|
||||
|
||||
def remove_attention_function(unet):
|
||||
cross_attention_modules = get_attention_modules(unet, CrossAttentionType.TOKENS) + get_attention_modules(unet, CrossAttentionType.SELF)
|
||||
cross_attention_modules = get_cross_attention_modules(unet, CrossAttentionType.TOKENS) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
|
||||
for identifier, module in cross_attention_modules:
|
||||
try:
|
||||
# clear wrangler callback
|
||||
@ -316,7 +442,6 @@ def get_mem_free_total(device):
|
||||
return mem_free_total
|
||||
|
||||
|
||||
|
||||
class InvokeAICrossAttentionMixin:
|
||||
"""
|
||||
Enable InvokeAI-flavoured CrossAttention calculation, which does aggressive low-memory slicing and calls
|
||||
@ -452,4 +577,3 @@ class InvokeAIDiffusersCrossAttention(diffusers.models.attention.CrossAttention,
|
||||
hidden_states = self.reshape_batch_dim_to_heads(damian_result)
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
95
ldm/models/diffusion/cross_attention_map_saving.py
Normal file
95
ldm/models/diffusion/cross_attention_map_saving.py
Normal file
@ -0,0 +1,95 @@
|
||||
import math
|
||||
|
||||
import PIL
|
||||
import torch
|
||||
from torchvision.transforms.functional import resize as tv_resize, InterpolationMode
|
||||
|
||||
from ldm.models.diffusion.cross_attention_control import get_cross_attention_modules, CrossAttentionType
|
||||
|
||||
|
||||
class AttentionMapSaver():
|
||||
|
||||
def __init__(self, token_ids: range, latents_shape: torch.Size):
|
||||
self.token_ids = token_ids
|
||||
self.latents_shape = latents_shape
|
||||
#self.collated_maps = #torch.zeros([len(token_ids), latents_shape[0], latents_shape[1]])
|
||||
self.collated_maps = {}
|
||||
|
||||
def clear_maps(self):
|
||||
self.collated_maps = {}
|
||||
|
||||
def add_attention_maps(self, maps: torch.Tensor, key: str):
|
||||
"""
|
||||
Accumulate the given attention maps and store by summing with existing maps at the passed-in key (if any).
|
||||
:param maps: Attention maps to store. Expected shape [A, (H*W), N] where A is attention heads count, H and W are the map size (fixed per-key) and N is the number of tokens (typically 77).
|
||||
:param key: Storage key. If a map already exists for this key it will be summed with the incoming data. In this case the maps sizes (H and W) should match.
|
||||
:return: None
|
||||
"""
|
||||
key_and_size = f'{key}_{maps.shape[1]}'
|
||||
|
||||
# extract desired tokens
|
||||
maps = maps[:, :, self.token_ids]
|
||||
|
||||
# merge attention heads to a single map per token
|
||||
maps = torch.sum(maps, 0)
|
||||
|
||||
# store
|
||||
if key_and_size not in self.collated_maps:
|
||||
self.collated_maps[key_and_size] = torch.zeros_like(maps, device='cpu')
|
||||
self.collated_maps[key_and_size] += maps.cpu()
|
||||
|
||||
def write_maps_to_disk(self, path: str):
|
||||
pil_image = self.get_stacked_maps_image()
|
||||
pil_image.save(path, 'PNG')
|
||||
|
||||
def get_stacked_maps_image(self) -> PIL.Image:
|
||||
"""
|
||||
Scale all collected attention maps to the same size, blend them together and return as an image.
|
||||
:return: An image containing a vertical stack of blended attention maps, one for each requested token.
|
||||
"""
|
||||
num_tokens = len(self.token_ids)
|
||||
if num_tokens == 0:
|
||||
return None
|
||||
|
||||
latents_height = self.latents_shape[0]
|
||||
latents_width = self.latents_shape[1]
|
||||
|
||||
merged = None
|
||||
|
||||
for key, maps in self.collated_maps.items():
|
||||
|
||||
# maps has shape [(H*W), N] for N tokens
|
||||
# but we want [N, H, W]
|
||||
this_scale_factor = math.sqrt(maps.shape[0] / (latents_width * latents_height))
|
||||
this_maps_height = int(float(latents_height) * this_scale_factor)
|
||||
this_maps_width = int(float(latents_width) * this_scale_factor)
|
||||
# and we need to do some dimension juggling
|
||||
maps = torch.reshape(torch.swapdims(maps, 0, 1), [num_tokens, this_maps_height, this_maps_width])
|
||||
|
||||
# scale to output size if necessary
|
||||
if this_scale_factor != 1:
|
||||
maps = tv_resize(maps, [latents_height, latents_width], InterpolationMode.BICUBIC)
|
||||
|
||||
# normalize
|
||||
maps_min = torch.min(maps)
|
||||
maps_range = torch.max(maps) - maps_min
|
||||
#print(f"map {key} size {[this_maps_width, this_maps_height]} range {[maps_min, maps_min + maps_range]}")
|
||||
maps_normalized = (maps - maps_min) / maps_range
|
||||
# expand to (-0.1, 1.1) and clamp
|
||||
maps_normalized_expanded = maps_normalized * 1.1 - 0.05
|
||||
maps_normalized_expanded_clamped = torch.clamp(maps_normalized_expanded, 0, 1)
|
||||
|
||||
# merge together, producing a vertical stack
|
||||
maps_stacked = torch.reshape(maps_normalized_expanded_clamped, [num_tokens * latents_height, latents_width])
|
||||
|
||||
if merged is None:
|
||||
merged = maps_stacked
|
||||
else:
|
||||
# screen blend
|
||||
merged = 1 - (1 - maps_stacked)*(1 - merged)
|
||||
|
||||
if merged is None:
|
||||
return None
|
||||
|
||||
merged_bytes = merged.mul(0xff).byte()
|
||||
return PIL.Image.fromarray(merged_bytes.numpy(), mode='L')
|
@ -4,6 +4,7 @@ import k_diffusion as K
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .cross_attention_map_saving import AttentionMapSaver
|
||||
from .sampler import Sampler
|
||||
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
|
||||
@ -36,6 +37,7 @@ class CFGDenoiser(nn.Module):
|
||||
self.invokeai_diffuser = InvokeAIDiffuserComponent(model,
|
||||
model_forward_callback=lambda x, sigma, cond: self.inner_model(x, sigma, cond=cond))
|
||||
|
||||
|
||||
def prepare_to_sample(self, t_enc, **kwargs):
|
||||
|
||||
extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
|
||||
@ -106,12 +108,12 @@ class KSampler(Sampler):
|
||||
else:
|
||||
print(f'>> Ksampler using karras noise schedule (steps < {self.karras_max})')
|
||||
self.sigmas = self.karras_sigmas
|
||||
|
||||
|
||||
# ALERT: We are completely overriding the sample() method in the base class, which
|
||||
# means that inpainting will not work. To get this to work we need to be able to
|
||||
# modify the inner loop of k_heun, k_lms, etc, as is done in an ugly way
|
||||
# in the lstein/k-diffusion branch.
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(
|
||||
self,
|
||||
@ -145,7 +147,7 @@ class KSampler(Sampler):
|
||||
@torch.no_grad()
|
||||
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
|
||||
return x0
|
||||
|
||||
|
||||
# Most of these arguments are ignored and are only present for compatibility with
|
||||
# other samples
|
||||
@torch.no_grad()
|
||||
@ -158,6 +160,7 @@ class KSampler(Sampler):
|
||||
callback=None,
|
||||
normals_sequence=None,
|
||||
img_callback=None,
|
||||
attention_maps_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.0,
|
||||
mask=None,
|
||||
@ -171,7 +174,7 @@ class KSampler(Sampler):
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
extra_conditioning_info=None,
|
||||
extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo=None,
|
||||
threshold = 0,
|
||||
perlin = 0,
|
||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
@ -204,6 +207,12 @@ class KSampler(Sampler):
|
||||
|
||||
model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10))
|
||||
model_wrap_cfg.prepare_to_sample(S, extra_conditioning_info=extra_conditioning_info)
|
||||
|
||||
attention_map_token_ids = range(1, extra_conditioning_info.tokens_count_including_eos_bos - 1)
|
||||
attention_maps_saver = None if attention_maps_callback is None else AttentionMapSaver(token_ids = attention_map_token_ids, latents_shape=x.shape[-2:])
|
||||
if attention_maps_callback is not None:
|
||||
model_wrap_cfg.invokeai_diffuser.setup_attention_map_saving(attention_maps_saver)
|
||||
|
||||
extra_args = {
|
||||
'cond': conditioning,
|
||||
'uncond': unconditional_conditioning,
|
||||
@ -217,6 +226,8 @@ class KSampler(Sampler):
|
||||
),
|
||||
None,
|
||||
)
|
||||
if attention_maps_callback is not None:
|
||||
attention_maps_callback(attention_maps_saver)
|
||||
return sampling_result
|
||||
|
||||
# this code will support inpainting if and when ksampler API modified or
|
||||
@ -248,7 +259,7 @@ class KSampler(Sampler):
|
||||
# terrible, confusing names here
|
||||
steps = self.ddim_num_steps
|
||||
t_enc = self.t_enc
|
||||
|
||||
|
||||
# sigmas is a full steps in length, but t_enc might
|
||||
# be less. We start in the middle of the sigma array
|
||||
# and work our way to the end after t_enc steps.
|
||||
@ -280,7 +291,7 @@ class KSampler(Sampler):
|
||||
return x_T + x
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
def prepare_to_sample(self,t_enc,**kwargs):
|
||||
self.t_enc = t_enc
|
||||
self.model_wrap = None
|
||||
|
@ -5,8 +5,8 @@ from typing import Callable, Optional, Union
|
||||
import torch
|
||||
|
||||
from ldm.models.diffusion.cross_attention_control import Arguments, \
|
||||
remove_cross_attention_control, setup_cross_attention_control, Context
|
||||
from ldm.modules.attention import get_mem_free_total
|
||||
remove_cross_attention_control, setup_cross_attention_control, Context, get_cross_attention_modules, CrossAttentionType
|
||||
from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
|
||||
|
||||
|
||||
class InvokeAIDiffuserComponent:
|
||||
@ -21,7 +21,8 @@ class InvokeAIDiffuserComponent:
|
||||
|
||||
|
||||
class ExtraConditioningInfo:
|
||||
def __init__(self, cross_attention_control_args: Optional[Arguments]):
|
||||
def __init__(self, tokens_count_including_eos_bos:int, cross_attention_control_args: Optional[Arguments]):
|
||||
self.tokens_count_including_eos_bos = tokens_count_including_eos_bos
|
||||
self.cross_attention_control_args = cross_attention_control_args
|
||||
|
||||
@property
|
||||
@ -53,7 +54,25 @@ class InvokeAIDiffuserComponent:
|
||||
self.cross_attention_control_context = None
|
||||
remove_cross_attention_control(self.model)
|
||||
|
||||
def setup_attention_map_saving(self, saver: AttentionMapSaver):
|
||||
def callback(slice, dim, offset, slice_size, key):
|
||||
if dim is not None:
|
||||
# sliced tokens attention map saving is not implemented
|
||||
return
|
||||
saver.add_attention_maps(slice, key)
|
||||
|
||||
tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
|
||||
for identifier, module in tokens_cross_attention_modules:
|
||||
key = ('down' if identifier.startswith('down') else
|
||||
'up' if identifier.startswith('up') else
|
||||
'mid')
|
||||
module.set_attention_slice_calculated_callback(
|
||||
lambda slice, dim, offset, slice_size, key=key: callback(slice, dim, offset, slice_size, key))
|
||||
|
||||
def remove_attention_map_saving(self):
|
||||
tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
|
||||
for _, module in tokens_cross_attention_modules:
|
||||
module.set_attention_slice_calculated_callback(None)
|
||||
|
||||
def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor,
|
||||
unconditioning: Union[torch.Tensor,dict],
|
||||
|
@ -10,8 +10,6 @@ from einops import rearrange, repeat
|
||||
from ldm.models.diffusion.cross_attention_control import InvokeAICrossAttentionMixin
|
||||
from ldm.modules.diffusionmodules.util import checkpoint
|
||||
|
||||
import psutil
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
@ -165,10 +163,10 @@ def get_mem_free_total(device):
|
||||
return mem_free_total
|
||||
|
||||
class CrossAttention(nn.Module, InvokeAICrossAttentionMixin):
|
||||
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
|
||||
print(f"Warning! ldm.modules.attention.CrossAttention is no longer being maintained. Please use InvokeAICrossAttention instead.")
|
||||
super().__init__()
|
||||
InvokeAICrossAttentionMixin.__init__(self)
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = default(context_dim, query_dim)
|
||||
|
||||
@ -184,7 +182,6 @@ class CrossAttention(nn.Module, InvokeAICrossAttentionMixin):
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
|
||||
def forward(self, x, context=None, mask=None):
|
||||
h = self.heads
|
||||
|
||||
@ -196,7 +193,7 @@ class CrossAttention(nn.Module, InvokeAICrossAttentionMixin):
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||
|
||||
# prevent scale being applied twice
|
||||
# don't apply scale twice
|
||||
cached_scale = self.scale
|
||||
self.scale = 1
|
||||
r = self.get_invokeai_attention_mem_efficient(q, k, v)
|
||||
|
Loading…
Reference in New Issue
Block a user