Save and display per-token attention maps (#1866)

* attention maps saving to /tmp

* tidy up diffusers branch backporting of cross attention refactoring

* base64-encoding the attention maps image for generationResult

* cleanup/refactor conditioning.py

* attention maps and tokens being sent to web UI

* attention maps: restrict count to actual token count and improve robustness

* add argument type hint to image_to_dataURL function

Co-authored-by: damian <git@damianstewart.com>
Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Damian Stewart 2022-12-10 15:57:41 +01:00 committed by GitHub
parent 55132f6463
commit 786b8878d6
13 changed files with 636 additions and 346 deletions
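At a glance, the commit wires the per-token attention maps through three layers. The lines below are lifted from the hunks that follow and are only collected here as a summary sketch:

# the generator collects the stacked per-token maps from the sampler...
attention_maps_callback = lambda saver: attention_maps_images.append(saver.get_stacked_maps_image())
# ...hands the most recent one to the caller alongside the finished image...
image_callback(image, seed, first_seed=first_seed, attention_maps_image=attention_maps_images[-1])
# ...and the web server base64-encodes it for the "generationResult" socketio event
attention_maps_image_base64_url = None if attention_maps_image is None else image_to_dataURL(attention_maps_image)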

View File

@@ -18,9 +18,11 @@ from PIL.Image import Image as ImageType
from uuid import uuid4
from threading import Event
+from ldm.generate import Generate
from ldm.invoke.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash
+from ldm.invoke.conditioning import get_tokens_for_prompt, get_prompt_structure
from ldm.invoke.pngwriter import PngWriter, retrieve_metadata
-from ldm.invoke.prompt_parser import split_weighted_subprompts
+from ldm.invoke.prompt_parser import split_weighted_subprompts, Blend
from ldm.invoke.generator.inpaint import infill_methods
from backend.modules.parameters import parameters_to_command
@@ -39,7 +41,7 @@ if not os.path.isabs(args.outdir):
class InvokeAIWebServer:
-def __init__(self, generate, gfpgan, codeformer, esrgan) -> None:
+def __init__(self, generate: Generate, gfpgan, codeformer, esrgan) -> None:
self.host = args.host
self.port = args.port
@@ -905,16 +907,13 @@
},
)
if generation_parameters["progress_latents"]:
image = self.generate.sample_to_lowres_estimated_image(sample)
(width, height) = image.size
width *= 8
height *= 8
-buffered = io.BytesIO()
-image.save(buffered, format="PNG")
-img_base64 = "data:image/png;base64," + base64.b64encode(
-buffered.getvalue()
-).decode("UTF-8")
+img_base64 = image_to_dataURL(image)
self.socketio.emit(
"intermediateResult",
{
@@ -932,7 +931,7 @@
self.socketio.emit("progressUpdate", progress.to_formatted_dict())
eventlet.sleep(0)
-def image_done(image, seed, first_seed):
+def image_done(image, seed, first_seed, attention_maps_image=None):
if self.canceled.is_set():
raise CanceledException
@@ -1094,6 +1093,12 @@
self.socketio.emit("progressUpdate", progress.to_formatted_dict())
eventlet.sleep(0)
+parsed_prompt, _ = get_prompt_structure(generation_parameters["prompt"])
+tokens = None if type(parsed_prompt) is Blend else \
+get_tokens_for_prompt(self.generate.model, parsed_prompt)
+attention_maps_image_base64_url = None if attention_maps_image is None \
+else image_to_dataURL(attention_maps_image)
self.socketio.emit(
"generationResult",
{
@@ -1106,6 +1111,8 @@
"height": height,
"boundingBox": original_bounding_box,
"generationMode": generation_parameters["generation_mode"],
+"attentionMaps": attention_maps_image_base64_url,
+"tokens": tokens,
},
)
eventlet.sleep(0)
@@ -1117,7 +1124,7 @@
self.generate.prompt2image(
**generation_parameters,
step_callback=image_progress,
-image_callback=image_done,
+image_callback=image_done
)
except KeyboardInterrupt:
@@ -1564,6 +1571,19 @@ def dataURL_to_image(dataURL: str) -> ImageType:
)
return image
+"""
+Converts an image into a base64 image dataURL.
+"""
+def image_to_dataURL(image: ImageType) -> str:
+buffered = io.BytesIO()
+image.save(buffered, format="PNG")
+image_base64 = "data:image/png;base64," + base64.b64encode(
+buffered.getvalue()
+).decode("UTF-8")
+return image_base64
"""
Converts a base64 image dataURL into bytes.
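For reference, a minimal usage sketch of the new image_to_dataURL helper (not part of the diff; PIL is already used throughout this module):

from PIL import Image

image = Image.new("RGB", (64, 64))
data_url = image_to_dataURL(image)
# data_url begins with "data:image/png;base64," and can be used directly as an <img> src
assert data_url.startswith("data:image/png;base64,")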

View File

@@ -20,6 +20,8 @@ import cv2
import skimage
from omegaconf import OmegaConf
+import ldm.invoke.conditioning
from ldm.invoke.generator.base import downsampling
from PIL import Image, ImageOps
from torch import nn
@@ -40,7 +42,7 @@ from ldm.invoke.model_cache import ModelCache
from ldm.invoke.seamless import configure_model_padding
from ldm.invoke.txt2mask import Txt2Mask, SegmentedGrayscale
from ldm.invoke.concepts_lib import Concepts
def fix_func(orig):
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
def new_func(*args, **kw):
@@ -235,7 +237,7 @@ class Generate:
except Exception:
print('** An error was encountered while installing the safety checker:')
print(traceback.format_exc())
def prompt2png(self, prompt, outdir, **kwargs):
"""
Takes a prompt and an output directory, writes out the requested number
@@ -329,7 +331,7 @@ class Generate:
infill_method = infill_methods[0], # The infill method to use
force_outpaint: bool = False,
enable_image_debugging = False,
**args,
): # eat up additional cruft
"""
@@ -372,7 +374,7 @@ class Generate:
def process_image(image,seed):
image.save(f{'images/seed.png'})
The code used to save images to a directory can be found in ldm/invoke/pngwriter.py.
It contains code to create the requested output directory, select a unique informative
name for each image, and write the prompt into the PNG metadata.
"""
@@ -455,7 +457,7 @@ class Generate:
try:
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(
prompt, model =self.model,
-skip_normalize=skip_normalize,
+skip_normalize_legacy_blend=skip_normalize,
log_tokens =self.log_tokenization
)
@@ -589,7 +591,7 @@ class Generate:
seed = opt.seed or args.seed
if seed is None or seed < 0:
seed = random.randrange(0, np.iinfo(np.uint32).max)
prompt = opt.prompt or args.prompt or ''
print(f'>> using seed {seed} and prompt "{prompt}" for {image_path}')
@@ -607,8 +609,8 @@ class Generate:
# todo: cross-attention control
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(
prompt, model =self.model,
-skip_normalize=opt.skip_normalize,
-log_tokens =opt.log_tokenization
+skip_normalize_legacy_blend=opt.skip_normalize,
+log_tokens =ldm.invoke.conditioning.log_tokenization
)
if tool in ('gfpgan','codeformer','upscale'):
@@ -641,7 +643,7 @@ class Generate:
opt.seed = seed
opt.prompt = prompt
if len(extend_instructions) > 0:
restorer = Outcrop(image,self,)
return restorer.process (
@@ -683,7 +685,7 @@ class Generate:
image_callback = callback,
prefix = prefix
)
elif tool is None:
print(f'* please provide at least one postprocessing option, such as -G or -U')
return None
@@ -706,13 +708,13 @@ class Generate:
if embiggen is not None:
return self._make_embiggen()
if inpainting_model_in_use:
return self._make_omnibus()
if ((init_image is not None) and (mask_image is not None)) or force_outpaint:
return self._make_inpaint()
if init_image is not None:
return self._make_img2img()
@@ -743,7 +745,7 @@ class Generate:
if self._has_transparency(image):
self._transparency_check_and_warning(image, mask, force_outpaint)
init_mask = self._create_init_mask(image, width, height, fit=fit)
if (image.width * image.height) > (self.width * self.height) and self.size_matters:
print(">> This input is larger than your defaults. If you run out of memory, please use a smaller image.")
self.size_matters = False
@@ -759,7 +761,7 @@ class Generate:
if init_mask and invert_mask:
init_mask = ImageOps.invert(init_mask)
return init_image,init_mask
# lots o' repeated code here! Turn into a make_func()
@@ -818,7 +820,7 @@ class Generate:
self.set_model(self.model_name)
def set_model(self,model_name):
"""
Given the name of a model defined in models.yaml, will load and initialize it
and return the model object. Previously-used models will be cached.
"""
@@ -830,7 +832,7 @@ class Generate:
if not cache.valid_model(model_name):
print(f'** "{model_name}" is not a known model name. Please check your models.yaml file')
return self.model
cache.print_vram_usage()
# have to get rid of all references to model in order
@@ -839,7 +841,7 @@ class Generate:
self.sampler = None
self.generators = {}
gc.collect()
model_data = cache.get_model(model_name)
if model_data is None: # restore previous
model_data = cache.get_model(self.model_name)
@@ -852,7 +854,7 @@ class Generate:
# uncache generators so they pick up new models
self.generators = {}
seed_everything(random.randrange(0, np.iinfo(np.uint32).max))
if self.embedding_path is not None:
self.model.embedding_manager.load(
@@ -901,7 +903,7 @@ class Generate:
image_callback = None,
prefix = None,
):
for r in image_list:
image, seed = r
try:
@@ -911,7 +913,7 @@ class Generate:
if self.gfpgan is None:
print('>> GFPGAN not found. Face restoration is disabled.')
else:
image = self.gfpgan.process(image, strength, seed)
if facetool == 'codeformer':
if self.codeformer is None:
print('>> CodeFormer not found. Face restoration is disabled.')
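The visible change at these call sites is the renamed skip_normalize_legacy_blend keyword (and, in the second case, the source of log_tokens). A minimal sketch of the first updated call, using the same names as above:

uc, c, extra_conditioning_info = get_uc_and_c_and_ec(
    prompt, model=self.model,
    skip_normalize_legacy_blend=skip_normalize,
    log_tokens=self.log_tokenization
)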

View File

@@ -8,6 +8,7 @@ import time
import traceback
import yaml
+from ldm.generate import Generate
from ldm.invoke.globals import Globals
from ldm.invoke.prompt_parser import PromptParser
from ldm.invoke.readline import get_completer, Completer
@@ -27,7 +28,7 @@ def main():
"""Initialize command-line parsers and the diffusion model"""
global infile
print('* Initializing, be patient...')
opt = Args()
args = opt.parse_args()
if not args:
@@ -47,7 +48,7 @@ def main():
# alert - setting globals here
Globals.root = os.path.expanduser(args.root_dir or os.environ.get('INVOKEAI_ROOT') or os.path.abspath('.'))
Globals.try_patchmatch = args.patchmatch
print(f'>> InvokeAI runtime directory is "{Globals.root}"')
# loading here to avoid long delays on startup
@@ -281,7 +282,7 @@ def main_loop(gen, opt):
prefix = file_writer.unique_prefix()
step_callback = make_step_callback(gen, opt, prefix) if opt.save_intermediates > 0 else None
-def image_writer(image, seed, upscaled=False, first_seed=None, use_prefix=None, prompt_in=None):
+def image_writer(image, seed, upscaled=False, first_seed=None, use_prefix=None, prompt_in=None, attention_maps_image=None):
# note the seed is the seed of the current image
# the first_seed is the original seed that noise is added to
# when the -v switch is used to generate variations
@@ -341,8 +342,8 @@ def main_loop(gen, opt):
filename,
tool,
formatted_dream_prompt,
)
if (not postprocessed) or opt.save_original:
# only append to results if we didn't overwrite an earlier output
results.append([path, formatted_dream_prompt])
@@ -432,7 +433,7 @@ def do_command(command:str, gen, opt:Args, completer) -> tuple:
add_embedding_terms(gen, completer)
completer.add_history(command)
operation = None
elif command.startswith('!models'):
gen.model_cache.print_models()
completer.add_history(command)
@@ -533,7 +534,7 @@ def add_weights_to_config(model_path:str, gen, opt, completer):
completer.complete_extensions(('.yaml','.yml'))
completer.linebuffer = 'configs/stable-diffusion/v1-inference.yaml'
done = False
while not done:
new_config['config'] = input('Configuration file for this model: ')
@@ -564,7 +565,7 @@ def add_weights_to_config(model_path:str, gen, opt, completer):
print('** Please enter a valid integer between 64 and 2048')
make_default = input('Make this the default model? [n] ') in ('y','Y')
if write_config_file(opt.conf, gen, model_name, new_config, make_default=make_default):
completer.add_model(model_name)
@@ -577,14 +578,14 @@ def del_config(model_name:str, gen, opt, completer):
gen.model_cache.commit(opt.conf)
print(f'** {model_name} deleted')
completer.del_model(model_name)
def edit_config(model_name:str, gen, opt, completer):
config = gen.model_cache.config
if model_name not in config:
print(f'** Unknown model {model_name}')
return
print(f'\n>> Editing model {model_name} from configuration file {opt.conf}')
conf = config[model_name]
@@ -597,10 +598,10 @@ def edit_config(model_name:str, gen, opt, completer):
make_default = input('Make this the default model? [n] ') in ('y','Y')
completer.complete_extensions(None)
write_config_file(opt.conf, gen, model_name, new_config, clobber=True, make_default=make_default)
def write_config_file(conf_path, gen, model_name, new_config, clobber=False, make_default=False):
current_model = gen.model_name
op = 'modify' if clobber else 'import'
print('\n>> New configuration:')
if make_default:
@@ -623,7 +624,7 @@ def write_config_file(conf_path, gen, model_name, new_config, clobber=False, mak
gen.model_cache.set_default_model(model_name)
gen.model_cache.commit(conf_path)
do_switch = input(f'Keep model loaded? [y]')
if len(do_switch)==0 or do_switch[0] in ('y','Y'):
pass
@@ -653,7 +654,7 @@ def do_postprocess (gen, opt, callback):
opt.prompt = opt.new_prompt
else:
opt.prompt = None
if os.path.dirname(file_path) == '': #basename given
file_path = os.path.join(opt.outdir,file_path)
@@ -718,7 +719,7 @@ def add_postprocessing_to_metadata(opt,original_file,new_file,tool,command):
)
meta['image']['postprocessing'] = pp
write_metadata(new_file,meta)
def prepare_image_metadata(
opt,
prefix,
@@ -789,28 +790,28 @@ def get_next_command(infile=None) -> str: # command string
print(f'#{command}')
return command
-def invoke_ai_web_server_loop(gen, gfpgan, codeformer, esrgan):
+def invoke_ai_web_server_loop(gen: Generate, gfpgan, codeformer, esrgan):
print('\n* --web was specified, starting web server...')
from backend.invoke_ai_web_server import InvokeAIWebServer
# Change working directory to the stable-diffusion directory
os.chdir(
os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
)
invoke_ai_web_server = InvokeAIWebServer(generate=gen, gfpgan=gfpgan, codeformer=codeformer, esrgan=esrgan)
try:
invoke_ai_web_server.run()
except KeyboardInterrupt:
pass
def add_embedding_terms(gen,completer):
'''
Called after setting the model, updates the autocompleter with
any terms loaded by the embedding manager.
'''
completer.add_embedding_terms(gen.model.embedding_manager.list_terms())
def split_variations(variations_string) -> list:
# shotgun parsing, woo
parts = []
@@ -867,7 +868,7 @@ def make_step_callback(gen, opt, prefix):
image = gen.sample_to_image(img)
image.save(filename,'PNG')
return callback
def retrieve_dream_command(opt,command,completer):
'''
Given a full or partial path to a previously-generated image file,
@@ -875,7 +876,7 @@ def retrieve_dream_command(opt,command,completer):
and pop it into the readline buffer (linux, Mac), or print out a comment
for cut-and-paste (windows)
Given a wildcard path to a folder with image png files,
will retrieve and format the dream command used to generate the images,
and save them to a file commands.txt for further processing
'''
@@ -911,7 +912,7 @@ def write_commands(opt, file_path:str, outfilepath:str):
except ValueError:
print(f'## "{basename}": unacceptable pattern')
return
commands = []
cmd = None
for path in paths:
@@ -940,7 +941,7 @@ def emergency_model_reconfigure():
print(' After reconfiguration is done, please relaunch invoke.py. ')
print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
print('configure_invokeai is launching....\n')
sys.argv = ['configure_invokeai','--interactive']
import configure_invokeai
configure_invokeai.main()
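A minimal sketch of how a CLI-side image_writer can use the new keyword argument; the /tmp destination echoes the commit message, and the exact file naming here is illustrative only:

def image_writer(image, seed, upscaled=False, first_seed=None, use_prefix=None,
                 prompt_in=None, attention_maps_image=None):
    # the finished image and, when available, the stacked per-token attention maps (both PIL images)
    image.save(f'/tmp/{seed}.png', 'PNG')
    if attention_maps_image is not None:
        attention_maps_image.save(f'/tmp/{seed}.attention.png', 'PNG')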

View File

@@ -7,20 +7,46 @@ get_uc_and_c_and_ec() get the conditioned and unconditioned latent, an
'''
import re
-from difflib import SequenceMatcher
from typing import Union
import torch
from .prompt_parser import PromptParser, Blend, FlattenedPrompt, \
-CrossAttentionControlledFragment, CrossAttentionControlSubstitute, Fragment, log_tokenization
+CrossAttentionControlledFragment, CrossAttentionControlSubstitute, Fragment
from ..models.diffusion import cross_attention_control
from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder
-def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_normalize=False):
+def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False):
+prompt, negative_prompt = get_prompt_structure(prompt_string,
+skip_normalize_legacy_blend=skip_normalize_legacy_blend)
+conditioning = _get_conditioning_for_prompt(prompt, negative_prompt, model, log_tokens)
+return conditioning
+def get_prompt_structure(prompt_string, skip_normalize_legacy_blend: bool = False) -> (
+Union[FlattenedPrompt, Blend], FlattenedPrompt):
+"""
+parse the passed-in prompt string and return tuple (positive_prompt, negative_prompt)
+"""
+prompt, negative_prompt = _parse_prompt_string(prompt_string,
+skip_normalize_legacy_blend=skip_normalize_legacy_blend)
+return prompt, negative_prompt
+def get_tokens_for_prompt(model, parsed_prompt: FlattenedPrompt) -> [str]:
+text_fragments = [x.text if type(x) is Fragment else
+(" ".join([f.text for f in x.original]) if type(x) is CrossAttentionControlSubstitute else
+str(x))
+for x in parsed_prompt.children]
+text = " ".join(text_fragments)
+tokens = model.cond_stage_model.tokenizer.tokenize(text)
+return tokens
+def _parse_prompt_string(prompt_string_uncleaned, skip_normalize_legacy_blend=False) -> Union[FlattenedPrompt, Blend]:
# Extract Unconditioned Words From Prompt
unconditioned_words = ''
unconditional_regex = r'\[(.*?)\]'
@@ -39,7 +65,7 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
pp = PromptParser()
parsed_prompt: Union[FlattenedPrompt, Blend] = None
-legacy_blend: Blend = pp.parse_legacy_blend(prompt_string_cleaned)
+legacy_blend: Blend = pp.parse_legacy_blend(prompt_string_cleaned, skip_normalize_legacy_blend)
if legacy_blend is not None:
parsed_prompt = legacy_blend
else:
@@ -47,129 +73,150 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
parsed_prompt = pp.parse_conjunction(prompt_string_cleaned).prompts[0]
parsed_negative_prompt: FlattenedPrompt = pp.parse_conjunction(unconditioned_words).prompts[0]
+return parsed_prompt, parsed_negative_prompt
+def _get_conditioning_for_prompt(parsed_prompt: Union[Blend, FlattenedPrompt], parsed_negative_prompt: FlattenedPrompt,
+model, log_tokens=False) \
+-> tuple[torch.Tensor, torch.Tensor, InvokeAIDiffuserComponent.ExtraConditioningInfo]:
+"""
+Process prompt structure and tokens, and return (conditioning, unconditioning, extra_conditioning_info)
+"""
if log_tokens:
print(f">> Parsed prompt to {parsed_prompt}")
print(f">> Parsed negative prompt to {parsed_negative_prompt}")
conditioning = None
-cac_args:cross_attention_control.Arguments = None
+cac_args: cross_attention_control.Arguments = None
if type(parsed_prompt) is Blend:
-blend: Blend = parsed_prompt
+conditioning = _get_conditioning_for_blend(model, parsed_prompt, log_tokens)
-embeddings_to_blend = None
+elif type(parsed_prompt) is FlattenedPrompt:
-for i,flattened_prompt in enumerate(blend.prompts):
+if parsed_prompt.wants_cross_attention_control:
-this_embedding, _ = build_embeddings_and_tokens_for_flattened_prompt(model,
+conditioning, cac_args = _get_conditioning_for_cross_attention_control(model, parsed_prompt, log_tokens)
-flattened_prompt,
-log_tokens=log_tokens,
-log_display_label=f"(blend part {i+1}, weight={blend.weights[i]})" )
-embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat(
-(embeddings_to_blend, this_embedding))
-conditioning = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0),
-blend.weights,
-normalize=blend.normalize_weights)
-else:
-flattened_prompt: FlattenedPrompt = parsed_prompt
-wants_cross_attention_control = type(flattened_prompt) is not Blend \
-and any([issubclass(type(x), CrossAttentionControlledFragment) for x in flattened_prompt.children])
-if wants_cross_attention_control:
-original_prompt = FlattenedPrompt()
-edited_prompt = FlattenedPrompt()
-# for name, a0, a1, b0, b1 in edit_opcodes: only name == 'equal' is currently parsed
-original_token_count = 0
-edited_token_count = 0
-edit_options = []
-edit_opcodes = []
-# beginning of sequence
-edit_opcodes.append(('equal', original_token_count, original_token_count+1, edited_token_count, edited_token_count+1))
-edit_options.append(None)
-original_token_count += 1
-edited_token_count += 1
-for fragment in flattened_prompt.children:
-if type(fragment) is CrossAttentionControlSubstitute:
-original_prompt.append(fragment.original)
-edited_prompt.append(fragment.edited)
-to_replace_token_count = get_tokens_length(model, fragment.original)
-replacement_token_count = get_tokens_length(model, fragment.edited)
-edit_opcodes.append(('replace',
-original_token_count, original_token_count + to_replace_token_count,
-edited_token_count, edited_token_count + replacement_token_count
-))
-original_token_count += to_replace_token_count
-edited_token_count += replacement_token_count
-edit_options.append(fragment.options)
-#elif type(fragment) is CrossAttentionControlAppend:
-# edited_prompt.append(fragment.fragment)
-else:
-# regular fragment
-original_prompt.append(fragment)
-edited_prompt.append(fragment)
-count = get_tokens_length(model, [fragment])
-edit_opcodes.append(('equal', original_token_count, original_token_count+count, edited_token_count, edited_token_count+count))
-edit_options.append(None)
-original_token_count += count
-edited_token_count += count
-# end of sequence
-edit_opcodes.append(('equal', original_token_count, original_token_count+1, edited_token_count, edited_token_count+1))
-edit_options.append(None)
-original_token_count += 1
-edited_token_count += 1
-original_embeddings, original_tokens = build_embeddings_and_tokens_for_flattened_prompt(model,
-original_prompt,
-log_tokens=log_tokens,
-log_display_label="(.swap originals)")
-# naïvely building a single edited_embeddings like this disregards the effects of changing the absolute location of
-# subsequent tokens when there is >1 edit and earlier edits change the total token count.
-# eg "a cat.swap(smiling dog, s_start=0.5) eating a hotdog.swap(pizza)" - when the 'pizza' edit is active but the
-# 'cat' edit is not, the 'pizza' feature vector will nevertheless be affected by the introduction of the extra
-# token 'smiling' in the inactive 'cat' edit.
-# todo: build multiple edited_embeddings, one for each edit, and pass just the edited fragments through to the CrossAttentionControl functions
-edited_embeddings, edited_tokens = build_embeddings_and_tokens_for_flattened_prompt(model,
-edited_prompt,
-log_tokens=log_tokens,
-log_display_label="(.swap replacements)")
-conditioning = original_embeddings
-edited_conditioning = edited_embeddings
-#print('>> got edit_opcodes', edit_opcodes, 'options', edit_options)
-cac_args = cross_attention_control.Arguments(
-edited_conditioning = edited_conditioning,
-edit_opcodes = edit_opcodes,
-edit_options = edit_options
-)
else:
-conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model,
+conditioning, _ = _get_embeddings_and_tokens_for_prompt(model,
-flattened_prompt,
+parsed_prompt,
log_tokens=log_tokens,
log_display_label="(prompt)")
+else:
+raise ValueError(f"parsed_prompt is '{type(parsed_prompt)}' which is not a supported prompt type")
-unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model,
+unconditioning, _ = _get_embeddings_and_tokens_for_prompt(model,
parsed_negative_prompt,
log_tokens=log_tokens,
log_display_label="(unconditioning)")
if isinstance(conditioning, dict):
# hybrid conditioning is in play
-unconditioning, conditioning = flatten_hybrid_conditioning(unconditioning, conditioning)
+unconditioning, conditioning = _flatten_hybrid_conditioning(unconditioning, conditioning)
if cac_args is not None:
-print(">> Hybrid conditioning cannot currently be combined with cross attention control. Cross attention control will be ignored.")
+print(
+">> Hybrid conditioning cannot currently be combined with cross attention control. Cross attention control will be ignored.")
cac_args = None
+eos_token_index = 1
+if type(parsed_prompt) is not Blend:
+tokens = get_tokens_for_prompt(model, parsed_prompt)
+eos_token_index = len(tokens)+1
return (
unconditioning, conditioning, InvokeAIDiffuserComponent.ExtraConditioningInfo(
+tokens_count_including_eos_bos=eos_token_index + 1,
cross_attention_control_args=cac_args
)
)
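+# A quick worked example of the bookkeeping above (prompt length assumed for illustration):
+# a positive prompt that tokenizes to 5 tokens gives
+#   eos_token_index = len(tokens) + 1 = 6   (BOS at 0, prompt tokens at 1..5, EOS at 6)
+#   tokens_count_including_eos_bos = eos_token_index + 1 = 7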
-def build_token_edit_opcodes(original_tokens, edited_tokens):
+def _get_conditioning_for_cross_attention_control(model, prompt: FlattenedPrompt, log_tokens: bool = True):
-original_tokens = original_tokens.cpu().numpy()[0]
+original_prompt = FlattenedPrompt()
-edited_tokens = edited_tokens.cpu().numpy()[0]
+edited_prompt = FlattenedPrompt()
+# for name, a0, a1, b0, b1 in edit_opcodes: only name == 'equal' is currently parsed
+original_token_count = 0
+edited_token_count = 0
+edit_options = []
+edit_opcodes = []
+# beginning of sequence
+edit_opcodes.append(
+('equal', original_token_count, original_token_count + 1, edited_token_count, edited_token_count + 1))
+edit_options.append(None)
+original_token_count += 1
+edited_token_count += 1
+for fragment in prompt.children:
+if type(fragment) is CrossAttentionControlSubstitute:
+original_prompt.append(fragment.original)
+edited_prompt.append(fragment.edited)
-return SequenceMatcher(None, original_tokens, edited_tokens).get_opcodes()
+to_replace_token_count = _get_tokens_length(model, fragment.original)
+replacement_token_count = _get_tokens_length(model, fragment.edited)
+edit_opcodes.append(('replace',
+original_token_count, original_token_count + to_replace_token_count,
+edited_token_count, edited_token_count + replacement_token_count
+))
+original_token_count += to_replace_token_count
+edited_token_count += replacement_token_count
+edit_options.append(fragment.options)
+# elif type(fragment) is CrossAttentionControlAppend:
+# edited_prompt.append(fragment.fragment)
+else:
+# regular fragment
+original_prompt.append(fragment)
+edited_prompt.append(fragment)
-def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt, log_tokens: bool=False, log_display_label: str=None):
+count = _get_tokens_length(model, [fragment])
+edit_opcodes.append(('equal', original_token_count, original_token_count + count, edited_token_count,
+edited_token_count + count))
+edit_options.append(None)
+original_token_count += count
+edited_token_count += count
+# end of sequence
+edit_opcodes.append(
+('equal', original_token_count, original_token_count + 1, edited_token_count, edited_token_count + 1))
+edit_options.append(None)
+original_token_count += 1
+edited_token_count += 1
+original_embeddings, original_tokens = _get_embeddings_and_tokens_for_prompt(model,
+original_prompt,
+log_tokens=log_tokens,
+log_display_label="(.swap originals)")
+# naïvely building a single edited_embeddings like this disregards the effects of changing the absolute location of
+# subsequent tokens when there is >1 edit and earlier edits change the total token count.
+# eg "a cat.swap(smiling dog, s_start=0.5) eating a hotdog.swap(pizza)" - when the 'pizza' edit is active but the
+# 'cat' edit is not, the 'pizza' feature vector will nevertheless be affected by the introduction of the extra
+# token 'smiling' in the inactive 'cat' edit.
+# todo: build multiple edited_embeddings, one for each edit, and pass just the edited fragments through to the CrossAttentionControl functions
+edited_embeddings, edited_tokens = _get_embeddings_and_tokens_for_prompt(model,
+edited_prompt,
+log_tokens=log_tokens,
+log_display_label="(.swap replacements)")
+conditioning = original_embeddings
+edited_conditioning = edited_embeddings
+# print('>> got edit_opcodes', edit_opcodes, 'options', edit_options)
+cac_args = cross_attention_control.Arguments(
+edited_conditioning=edited_conditioning,
+edit_opcodes=edit_opcodes,
+edit_options=edit_options
+)
+return conditioning, cac_args
+def _get_conditioning_for_blend(model, blend: Blend, log_tokens: bool = False):
+embeddings_to_blend = None
+for i, flattened_prompt in enumerate(blend.prompts):
+this_embedding, _ = _get_embeddings_and_tokens_for_prompt(model,
+flattened_prompt,
+log_tokens=log_tokens,
+log_display_label=f"(blend part {i + 1}, weight={blend.weights[i]})")
+embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat(
+(embeddings_to_blend, this_embedding))
+conditioning = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0),
+blend.weights,
+normalize=blend.normalize_weights)
+return conditioning
+def _get_embeddings_and_tokens_for_prompt(model, flattened_prompt: FlattenedPrompt, log_tokens: bool = False,
+log_display_label: str = None):
if type(flattened_prompt) is not FlattenedPrompt:
raise Exception(f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead")
fragments = [x.text for x in flattened_prompt.children]
@@ -181,12 +228,14 @@ def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: Fl
return embeddings, tokens
-def get_tokens_length(model, fragments: list[Fragment]):
+def _get_tokens_length(model, fragments: list[Fragment]):
fragment_texts = [x.text for x in fragments]
tokens = model.cond_stage_model.get_tokens(fragment_texts, include_start_and_end_markers=False)
return sum([len(x) for x in tokens])
-def flatten_hybrid_conditioning(uncond, cond):
+def _flatten_hybrid_conditioning(uncond, cond):
'''
This handles the choice between a conditional conditioning
that is a tensor (used by cross attention) vs one that has additional
@@ -205,4 +254,29 @@ def flatten_hybrid_conditioning(uncond, cond):
cond_flattened[k] = torch.cat([uncond[k], cond[k]])
return uncond, cond_flattened
+def log_tokenization(text, model, display_label=None):
+""" shows how the prompt is tokenized
+# usually tokens have '</w>' to indicate end-of-word,
+# but for readability it has been replaced with ' '
+"""
+tokens = model.cond_stage_model.tokenizer.tokenize(text)
+tokenized = ""
+discarded = ""
+usedTokens = 0
+totalTokens = len(tokens)
+for i in range(0, totalTokens):
+token = tokens[i].replace('</w>', ' ')
+# alternate color
+s = (usedTokens % 6) + 1
+if i < model.cond_stage_model.max_length:
+tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
+usedTokens += 1
+else: # over max token length
+discarded = discarded + f"\x1b[0;3{s};40m{token}"
+print(f"\n>> Tokens {display_label or ''} ({usedTokens}):\n{tokenized}\x1b[0m")
+if discarded != "":
+print(
+f">> Tokens Discarded ({totalTokens - usedTokens}):\n{discarded}\x1b[0m"
+)
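A minimal sketch of the new public API as the web server above uses it (the prompt string and the loaded model are assumptions for illustration):

parsed_prompt, negative_prompt = get_prompt_structure("a fluffy cat riding a bicycle")
tokens = None if type(parsed_prompt) is Blend else \
    get_tokens_for_prompt(model, parsed_prompt)
# tokens is a list of tokenizer strings such as 'a</w>', 'fluffy</w>', 'cat</w>', ...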

View File

@@ -14,6 +14,7 @@ import cv2 as cv
from einops import rearrange, repeat
from pytorch_lightning import seed_everything
from ldm.invoke.devices import choose_autocast
+from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
from ldm.util import rand_perlin_2d
downsampling = 8
@@ -51,9 +52,12 @@ class Generator():
def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None,
image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
safety_checker:dict=None,
+attention_maps_callback = None,
**kwargs):
scope = choose_autocast(self.precision)
self.safety_checker = safety_checker
+attention_maps_images = []
+attention_maps_callback = lambda saver: attention_maps_images.append(saver.get_stacked_maps_image())
make_image = self.get_make_image(
prompt,
sampler = sampler,
@@ -63,6 +67,7 @@ class Generator():
step_callback = step_callback,
threshold = threshold,
perlin = perlin,
+attention_maps_callback = attention_maps_callback,
**kwargs
)
results = []
@@ -98,12 +103,12 @@ class Generator():
results.append([image, seed])
if image_callback is not None:
-image_callback(image, seed, first_seed=first_seed)
+image_callback(image, seed, first_seed=first_seed, attention_maps_image=attention_maps_images[-1])
seed = self.new_seed()
return results
def sample_to_image(self,samples)->Image.Image:
"""
Given samples returned from a sampler, converts
@@ -166,12 +171,12 @@ class Generator():
blurred_init_mask = pil_init_mask
multiplied_blurred_init_mask = ImageChops.multiply(blurred_init_mask, self.pil_image.split()[-1])
# Paste original on color-corrected generation (using blurred mask)
matched_result.paste(init_image, (0,0), mask = multiplied_blurred_init_mask)
return matched_result
def sample_to_lowres_estimated_image(self,samples):
# origingally adapted from code by @erucipe and @keturn here:
@@ -219,11 +224,11 @@ class Generator():
(txt2img) or from the latent image (img2img, inpaint)
"""
raise NotImplementedError("get_noise() must be implemented in a descendent class")
def get_perlin_noise(self,width,height):
fixdevice = 'cpu' if (self.model.device.type == 'mps') else self.model.device
return torch.stack([rand_perlin_2d((height, width), (8, 8), device = self.model.device).to(fixdevice) for _ in range(self.latent_channels)], dim=0).to(self.model.device)
def new_seed(self):
self.seed = random.randrange(0, np.iinfo(np.uint32).max)
return self.seed
@@ -325,4 +330,4 @@ class Generator():
os.makedirs(dirname, exist_ok=True)
image.save(filepath,'PNG')

View File

@@ -14,7 +14,9 @@ class Txt2Img(Generator):
@torch.no_grad()
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
-conditioning,width,height,step_callback=None,threshold=0.0,perlin=0.0,**kwargs):
+conditioning,width,height,step_callback=None,threshold=0.0,perlin=0.0,
+attention_maps_callback=None,
+**kwargs):
"""
Returns a function returning an image derived from the prompt and the initial image
Return value depends on the seed at the time you call it
@@ -33,7 +35,7 @@ class Txt2Img(Generator):
if self.free_gpu_mem and self.model.model.device != self.model.device:
self.model.model.to(self.model.device)
sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False)
samples, _ = sampler.sample(
@@ -49,6 +51,7 @@ class Txt2Img(Generator):
eta = ddim_eta,
img_callback = step_callback,
threshold = threshold,
+attention_maps_callback = attention_maps_callback,
)
if self.free_gpu_mem:
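A sketch of supplying a custom callback through the new parameter; the saver's get_stacked_maps_image() is the same method the base Generator relies on, the /tmp path echoes the commit message, and the other argument values are placeholders:

def dump_attention_maps(saver):
    # presumably invoked by the sampler once the requested attention maps have been collected
    saver.get_stacked_maps_image().save('/tmp/attention_maps.png', 'PNG')

make_image = generator.get_make_image(prompt, sampler=sampler, steps=50, cfg_scale=7.5,
                                       ddim_eta=0.0, conditioning=conditioning,
                                       width=512, height=512,
                                       attention_maps_callback=dump_attention_maps)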

View File

@@ -3,7 +3,7 @@ from typing import Union, Optional
import re
import pyparsing as pp
'''
This module parses prompt strings and produces tree-like structures that can be used generate and control the conditioning tensors.
weighted subprompts.
Useful class exports:
@@ -69,6 +69,12 @@ class FlattenedPrompt():
return len(self.children) == 0 or \
(len(self.children) == 1 and len(self.children[0].text) == 0)
+@property
+def wants_cross_attention_control(self):
+return any(
+[issubclass(type(x), CrossAttentionControlledFragment) for x in self.children]
+)
def __repr__(self):
return f"FlattenedPrompt:{self.children}"
def __eq__(self, other):
@@ -240,6 +246,12 @@ class Blend():
self.weights = weights
self.normalize_weights = normalize_weights
+@property
+def wants_cross_attention_control(self):
+# blends cannot cross-attention control
+return False
def __repr__(self):
return f"Blend:{self.prompts} | weights {' ' if self.normalize_weights else '(non-normalized) '}{self.weights}"
def __eq__(self, other):
@@ -277,8 +289,8 @@ class PromptParser():
return self.flatten(root[0])
-def parse_legacy_blend(self, text: str) -> Optional[Blend]:
+def parse_legacy_blend(self, text: str, skip_normalize: bool) -> Optional[Blend]:
-weighted_subprompts = split_weighted_subprompts(text, skip_normalize=False)
+weighted_subprompts = split_weighted_subprompts(text, skip_normalize=skip_normalize)
if len(weighted_subprompts) <= 1:
return None
strings = [x[0] for x in weighted_subprompts]
@@ -287,7 +299,7 @@ class PromptParser():
parsed_conjunctions = [self.parse_conjunction(x) for x in strings]
flattened_prompts = [x.prompts[0] for x in parsed_conjunctions]
-return Blend(prompts=flattened_prompts, weights=weights, normalize_weights=True)
+return Blend(prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize)
def flatten(self, root: Conjunction, verbose = False) -> Conjunction:
@@ -641,27 +653,3 @@ def split_weighted_subprompts(text, skip_normalize=False)->list:
return [(x[0], equal_weight) for x in parsed_prompts]
return [(x[0], x[1] / weight_sum) for x in parsed_prompts]
-# shows how the prompt is tokenized
-# usually tokens have '</w>' to indicate end-of-word,
-# but for readability it has been replaced with ' '
-def log_tokenization(text, model, display_label=None):
-tokens = model.cond_stage_model.tokenizer.tokenize(text)
-tokenized = ""
-discarded = ""
-usedTokens = 0
-totalTokens = len(tokens)
-for i in range(0, totalTokens):
-token = tokens[i].replace('</w>', ' ')
-# alternate color
-s = (usedTokens % 6) + 1
-if i < model.cond_stage_model.max_length:
-tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
-usedTokens += 1
-else: # over max token length
-discarded = discarded + f"\x1b[0;3{s};40m{token}"
-print(f"\n>> Tokens {display_label or ''} ({usedTokens}):\n{tokenized}\x1b[0m")
-if discarded != "":
-print(
-f">> Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m"
-)
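A sketch of what threading the flag through changes, using the legacy 'prompt:weight' blend syntax handled by split_weighted_subprompts above (the exact weight values assume the normalization shown there):

pp = PromptParser()
blend = pp.parse_legacy_blend("a happy cat:1 a fierce dog:3", skip_normalize=False)
# weights are normalized to [0.25, 0.75] and normalize_weights is True;
# with skip_normalize=True the raw weights are kept and normalize_weights is False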

View File

@@ -53,7 +53,6 @@ COMMANDS = (
'--codeformer_fidelity','-cf',
'--upscale','-U',
'-save_orig','--save_original',
-'--skip_normalize','-x',
'--log_tokenization','-t',
'--hires_fix',
'--inpaint_replace','-r',
@@ -117,19 +116,19 @@ class Completer(object):
# extensions defined, so go directly into path completion mode
if self.extensions is not None:
self.matches = self._path_completions(text, state, self.extensions)
# looking for an image file
elif re.search(path_regexp,buffer):
do_shortcut = re.search('^'+'|'.join(IMG_FILE_COMMANDS),buffer)
self.matches = self._path_completions(text, state, IMG_EXTENSIONS,shortcut_ok=do_shortcut)
# looking for a seed
elif re.search('(-S\s*|--seed[=\s])\d*$',buffer):
self.matches= self._seed_completions(text,state)
elif re.search('<[\w-]*$',buffer):
self.matches= self._concept_completions(text,state)
# looking for a model
elif re.match('^'+'|'.join(MODEL_COMMANDS),buffer):
self.matches= self._model_completions(text, state)
@@ -227,7 +226,7 @@ class Completer(object):
if h_len < 1:
print('<empty history>')
return
for i in range(0,h_len):
line = self.get_history_item(i+1)
if match and match not in line:
@@ -367,7 +366,7 @@ class DummyCompleter(Completer):
def __init__(self,options):
super().__init__(options)
self.history = list()
def add_history(self,line):
self.history.append(line)

View File

@ -1,12 +1,14 @@
import enum import enum
from typing import Optional import math
from typing import Optional, Callable
import psutil
import torch import torch
from torch import nn
# adapted from bloc97's CrossAttentionControl colab # adapted from bloc97's CrossAttentionControl colab
# https://github.com/bloc97/CrossAttentionControl # https://github.com/bloc97/CrossAttentionControl
class Arguments: class Arguments:
def __init__(self, edited_conditioning: torch.Tensor, edit_opcodes: list[tuple], edit_options: dict): def __init__(self, edited_conditioning: torch.Tensor, edit_opcodes: list[tuple], edit_options: dict):
""" """
@ -63,9 +65,13 @@ class Context:
self.clear_requests(cleanup=True) self.clear_requests(cleanup=True)
def register_cross_attention_modules(self, model): def register_cross_attention_modules(self, model):
for name,module in get_attention_modules(model, CrossAttentionType.SELF): for name,module in get_cross_attention_modules(model, CrossAttentionType.SELF):
if name in self.self_cross_attention_module_identifiers:
assert False, f"name {name} cannot appear more than once"
self.self_cross_attention_module_identifiers.append(name) self.self_cross_attention_module_identifiers.append(name)
for name,module in get_attention_modules(model, CrossAttentionType.TOKENS): for name,module in get_cross_attention_modules(model, CrossAttentionType.TOKENS):
if name in self.tokens_cross_attention_module_identifiers:
assert False, f"name {name} cannot appear more than once"
self.tokens_cross_attention_module_identifiers.append(name) self.tokens_cross_attention_module_identifiers.append(name)
def request_save_attention_maps(self, cross_attention_type: CrossAttentionType): def request_save_attention_maps(self, cross_attention_type: CrossAttentionType):
@ -166,6 +172,135 @@ class Context:
map_dict[offset] = slice.to('cpu') map_dict[offset] = slice.to('cpu')
class InvokeAICrossAttentionMixin:
"""
Enable InvokeAI-flavoured CrossAttention calculation, which does aggressive low-memory slicing and calls
through both to an attention_slice_wrangler and a slicing_strategy_getter for custom attention map wrangling
and dymamic slicing strategy selection.
"""
def __init__(self):
self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
self.attention_slice_wrangler = None
self.slicing_strategy_getter = None
self.attention_slice_calculated_callback = None
def set_attention_slice_wrangler(self, wrangler: Optional[Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]]):
'''
Set custom attention calculator to be called when attention is calculated
:param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size),
which returns either the suggested_attention_slice or an adjusted equivalent.
`module` is the current CrossAttention module for which the callback is being invoked.
`suggested_attention_slice` is the default-calculated attention slice
`dim` is -1 if the attenion map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length.
Pass None to use the default attention calculation.
:return:
'''
self.attention_slice_wrangler = wrangler
def set_slicing_strategy_getter(self, getter: Optional[Callable[[nn.Module], tuple[int,int]]]):
self.slicing_strategy_getter = getter
def set_attention_slice_calculated_callback(self, callback: Optional[Callable[[torch.Tensor, int, int, int], None]]):
self.attention_slice_calculated_callback = callback
def einsum_lowest_level(self, query, key, value, dim, offset, slice_size):
# calculate attention scores
#attention_scores = torch.einsum('b i d, b j d -> b i j', q, k)
attention_scores = torch.baddbmm(
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
query,
key.transpose(-1, -2),
beta=0,
alpha=self.scale,
)
# calculate attention slice by taking the best scores for each latent pixel
default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
attention_slice_wrangler = self.attention_slice_wrangler
if attention_slice_wrangler is not None:
attention_slice = attention_slice_wrangler(self, default_attention_slice, dim, offset, slice_size)
else:
attention_slice = default_attention_slice
if self.attention_slice_calculated_callback is not None:
self.attention_slice_calculated_callback(attention_slice, dim, offset, slice_size)
hidden_states = torch.bmm(attention_slice, value)
return hidden_states
def einsum_op_slice_dim0(self, q, k, v, slice_size):
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
for i in range(0, q.shape[0], slice_size):
end = i + slice_size
r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size)
return r
def einsum_op_slice_dim1(self, q, k, v, slice_size):
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size)
return r
def einsum_op_mps_v1(self, q, k, v):
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
return self.einsum_lowest_level(q, k, v, None, None, None)
else:
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
return self.einsum_op_slice_dim1(q, k, v, slice_size)
def einsum_op_mps_v2(self, q, k, v):
if self.mem_total_gb > 8 and q.shape[1] <= 4096:
return self.einsum_lowest_level(q, k, v, None, None, None)
else:
return self.einsum_op_slice_dim0(q, k, v, 1)
def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb):
size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
if size_mb <= max_tensor_mb:
return self.einsum_lowest_level(q, k, v, None, None, None)
div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
if div <= q.shape[0]:
return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div)
return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1))
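A quick sanity check of the slice-size arithmetic above, as a standalone sketch (the tensor shapes and the 32 MiB budget are illustrative, not taken from this diff):

import torch

q = torch.zeros(16, 4096, 40)   # (batch*heads, tokens, dim_head), float32
k = torch.zeros(16, 4096, 40)
max_tensor_mb = 32

# full attention-score tensor: 16 * 4096 * 4096 * 4 bytes = 1024 MiB
size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
# (1024 - 1) / 32 -> 31, bit_length 5, so div = 2**5 = 32
div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
# div > q.shape[0], so dim-1 slicing is chosen: 4096 // 32 = 128 rows per slice
slice_size = max(q.shape[1] // div, 1)
slice_mb = q.shape[0] * slice_size * k.shape[1] * q.element_size() // (1 << 20)
print(size_mb, div, slice_size, slice_mb)   # 1024 32 128 32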
def einsum_op_cuda(self, q, k, v):
# check if we already have a slicing strategy (this should only happen during cross-attention controlled generation)
slicing_strategy_getter = self.slicing_strategy_getter
if slicing_strategy_getter is not None:
(dim, slice_size) = slicing_strategy_getter(self)
if dim is not None:
# print("using saved slicing strategy with dim", dim, "slice size", slice_size)
if dim == 0:
return self.einsum_op_slice_dim0(q, k, v, slice_size)
elif dim == 1:
return self.einsum_op_slice_dim1(q, k, v, slice_size)
# fallback for when there is no saved strategy, or saved strategy does not slice
mem_free_total = self.cached_mem_free_total or get_mem_free_total(q.device)
# Divide factor of safety as there's copying and fragmentation
return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
def get_invokeai_attention_mem_efficient(self, q, k, v):
if q.device.type == 'cuda':
#print("in get_attention_mem_efficient with q shape", q.shape, ", k shape", k.shape, ", free memory is", get_mem_free_total(q.device))
return self.einsum_op_cuda(q, k, v)
if q.device.type == 'mps' or q.device.type == 'cpu':
if self.mem_total_gb >= 32:
return self.einsum_op_mps_v1(q, k, v)
return self.einsum_op_mps_v2(q, k, v)
# Smaller slices are faster due to L2/L3/SLC caches.
# Tested on i7 with 8MB L3 cache.
return self.einsum_op_tensor_mem(q, k, v, 32)
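To see how these hooks are meant to be driven, here is a minimal standalone sketch; the ToyAttention class and tensor shapes are invented for illustration, and it only assumes the mixin is importable from the module shown in this diff:

import torch
from ldm.models.diffusion.cross_attention_control import InvokeAICrossAttentionMixin

class ToyAttention(InvokeAICrossAttentionMixin):
    def __init__(self):
        super().__init__()
        self.scale = 8 ** -0.5   # the baddbmm path reads self.scale as alpha

def passthrough_wrangler(module, suggested_attention_slice, dim, offset, slice_size):
    # inspect or replace the slice here; returning it unchanged keeps default behaviour
    return suggested_attention_slice

attn = ToyAttention()
attn.set_attention_slice_wrangler(passthrough_wrangler)
attn.set_attention_slice_calculated_callback(lambda attn_slice, dim, offset, size: None)

q = torch.randn(2, 64, 8)
k = torch.randn(2, 64, 8)
v = torch.randn(2, 64, 8)
print(attn.get_invokeai_attention_mem_efficient(q, k, v).shape)   # torch.Size([2, 64, 8])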
def remove_cross_attention_control(model):
    remove_attention_function(model)
@ -187,7 +322,7 @@ def setup_cross_attention_control(model, context: Context):
    # mask=1 means use base prompt attention, mask=0 means use edited prompt attention
    mask = torch.zeros(max_length)
    indices_target = torch.arange(max_length, dtype=torch.long)
-   indices = torch.zeros(max_length, dtype=torch.long)
+   indices = torch.arange(max_length, dtype=torch.long)
    for name, a0, a1, b0, b1 in context.arguments.edit_opcodes:
        if b0 < max_length:
            if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0):
@ -201,10 +336,23 @@ def setup_cross_attention_control(model, context: Context):
    inject_attention_function(model, context)


- def get_attention_modules(model, which: CrossAttentionType):
+ def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
+     cross_attention_class: type = InvokeAICrossAttentionMixin
+     # cross_attention_class: type = InvokeAIDiffusersCrossAttention
      which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2"
-     return [(name,module) for name, module in model.named_modules() if
-             type(module).__name__ == "CrossAttention" and which_attn in name]
+     attention_module_tuples = [(name,module) for name, module in model.named_modules() if
+                                isinstance(module, cross_attention_class) and which_attn in name]
cross_attention_modules_in_model_count = len(attention_module_tuples)
expected_count = 16
if cross_attention_modules_in_model_count != expected_count:
# non-fatal error but .swap() won't work.
print(f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model " +
f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching failed " +
f"or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, " +
f"and/or update the {expected_count} above to an appropriate number, and/or find and inform someone who knows " +
f"what it means. This error is non-fatal, but it is likely that .swap() and attention map display will not " +
f"work properly until it is fixed.")
return attention_module_tuples
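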
def inject_attention_function(unet, context: Context):
@ -244,19 +392,52 @@ def inject_attention_function(unet, context: Context):
        return attention_slice
-     for name, module in unet.named_modules():
-         module_name = type(module).__name__
-         if module_name == "CrossAttention":
-             module.identifier = name
-             module.set_attention_slice_wrangler(attention_slice_wrangler)
-             module.set_slicing_strategy_getter(lambda module, module_identifier=name: \
-                                                context.get_slicing_strategy(module_identifier))
+     cross_attention_modules = get_cross_attention_modules(unet, CrossAttentionType.TOKENS) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
+     for identifier, module in cross_attention_modules:
+         module.identifier = identifier
+         try:
+             module.set_attention_slice_wrangler(attention_slice_wrangler)
+             module.set_slicing_strategy_getter(
+                 lambda module: context.get_slicing_strategy(identifier)
+             )
+         except AttributeError as e:
+             if is_attribute_error_about(e, 'set_attention_slice_wrangler'):
+                 print(f"TODO: implement set_attention_slice_wrangler for {type(module)}") # TODO
+             else:
+                 raise
def remove_attention_function(unet):
-     # clear wrangler callback
-     for name, module in unet.named_modules():
-         module_name = type(module).__name__
-         if module_name == "CrossAttention":
-             module.set_attention_slice_wrangler(None)
-             module.set_slicing_strategy_getter(None)
+     cross_attention_modules = get_cross_attention_modules(unet, CrossAttentionType.TOKENS) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
+     for identifier, module in cross_attention_modules:
+         try:
+             # clear wrangler callback
+             module.set_attention_slice_wrangler(None)
+             module.set_slicing_strategy_getter(None)
+         except AttributeError as e:
+             if is_attribute_error_about(e, 'set_attention_slice_wrangler'):
+                 print(f"TODO: implement set_attention_slice_wrangler for {type(module)}")
+             else:
+                 raise
def is_attribute_error_about(error: AttributeError, attribute: str):
if hasattr(error, 'name'): # Python 3.10
return error.name == attribute
else: # Python 3.9
return attribute in str(error)
def get_mem_free_total(device):
#only on cuda
if not torch.cuda.is_available():
return None
stats = torch.cuda.memory_stats(device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(device)
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
return mem_free_total

View File

@ -0,0 +1,95 @@
import math
import PIL
import torch
from torchvision.transforms.functional import resize as tv_resize, InterpolationMode
from ldm.models.diffusion.cross_attention_control import get_cross_attention_modules, CrossAttentionType
class AttentionMapSaver():
def __init__(self, token_ids: range, latents_shape: torch.Size):
self.token_ids = token_ids
self.latents_shape = latents_shape
#self.collated_maps = #torch.zeros([len(token_ids), latents_shape[0], latents_shape[1]])
self.collated_maps = {}
def clear_maps(self):
self.collated_maps = {}
def add_attention_maps(self, maps: torch.Tensor, key: str):
"""
Accumulate the given attention maps and store by summing with existing maps at the passed-in key (if any).
:param maps: Attention maps to store. Expected shape [A, (H*W), N] where A is attention heads count, H and W are the map size (fixed per-key) and N is the number of tokens (typically 77).
:param key: Storage key. If a map already exists for this key it will be summed with the incoming data. In this case the maps sizes (H and W) should match.
:return: None
"""
key_and_size = f'{key}_{maps.shape[1]}'
# extract desired tokens
maps = maps[:, :, self.token_ids]
# merge attention heads to a single map per token
maps = torch.sum(maps, 0)
# store
if key_and_size not in self.collated_maps:
self.collated_maps[key_and_size] = torch.zeros_like(maps, device='cpu')
self.collated_maps[key_and_size] += maps.cpu()
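The shape bookkeeping above is easiest to see with dummy tensors; a standalone sketch with 8 heads, an 8x8 map, a 77-token context and tokens 1..5 kept (the import path is the new module added by this commit):

import torch
from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver

saver = AttentionMapSaver(token_ids=range(1, 6), latents_shape=torch.Size([64, 64]))
maps = torch.rand(8, 8 * 8, 77)             # [heads, H*W, tokens]
saver.add_attention_maps(maps, key='mid')
saver.add_attention_maps(maps, key='mid')   # same key and size: summed into the stored map
print(saver.collated_maps['mid_64'].shape)  # torch.Size([64, 5]) -> [(H*W), kept tokens]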
def write_maps_to_disk(self, path: str):
pil_image = self.get_stacked_maps_image()
pil_image.save(path, 'PNG')
def get_stacked_maps_image(self) -> PIL.Image:
"""
Scale all collected attention maps to the same size, blend them together and return as an image.
:return: An image containing a vertical stack of blended attention maps, one for each requested token.
"""
num_tokens = len(self.token_ids)
if num_tokens == 0:
return None
latents_height = self.latents_shape[0]
latents_width = self.latents_shape[1]
merged = None
for key, maps in self.collated_maps.items():
# maps has shape [(H*W), N] for N tokens
# but we want [N, H, W]
this_scale_factor = math.sqrt(maps.shape[0] / (latents_width * latents_height))
this_maps_height = int(float(latents_height) * this_scale_factor)
this_maps_width = int(float(latents_width) * this_scale_factor)
# and we need to do some dimension juggling
maps = torch.reshape(torch.swapdims(maps, 0, 1), [num_tokens, this_maps_height, this_maps_width])
# scale to output size if necessary
if this_scale_factor != 1:
maps = tv_resize(maps, [latents_height, latents_width], InterpolationMode.BICUBIC)
# normalize
maps_min = torch.min(maps)
maps_range = torch.max(maps) - maps_min
#print(f"map {key} size {[this_maps_width, this_maps_height]} range {[maps_min, maps_min + maps_range]}")
maps_normalized = (maps - maps_min) / maps_range
# expand to (-0.1, 1.1) and clamp
maps_normalized_expanded = maps_normalized * 1.1 - 0.05
maps_normalized_expanded_clamped = torch.clamp(maps_normalized_expanded, 0, 1)
# merge together, producing a vertical stack
maps_stacked = torch.reshape(maps_normalized_expanded_clamped, [num_tokens * latents_height, latents_width])
if merged is None:
merged = maps_stacked
else:
# screen blend
merged = 1 - (1 - maps_stacked)*(1 - merged)
if merged is None:
return None
merged_bytes = merged.mul(0xff).byte()
return PIL.Image.fromarray(merged_bytes.numpy(), mode='L')
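End to end, the saver can be exercised without a model at all; a hedged sketch with random maps at two hypothetical UNet resolutions (the output path is illustrative):

import torch
from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver

saver = AttentionMapSaver(token_ids=range(1, 6), latents_shape=torch.Size([64, 64]))
saver.add_attention_maps(torch.rand(8, 32 * 32, 77), key='down')   # 32x32 maps, upscaled to 64x64
saver.add_attention_maps(torch.rand(8, 64 * 64, 77), key='up')     # already at latent resolution

stacked = saver.get_stacked_maps_image()
print(stacked.size)                        # (64, 320): one 64x64 map per kept token, stacked vertically
saver.write_maps_to_disk('/tmp/attention_maps_demo.png')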

View File

@ -4,6 +4,7 @@ import k_diffusion as K
import torch
from torch import nn

+ from .cross_attention_map_saving import AttentionMapSaver
from .sampler import Sampler
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent
@ -36,6 +37,7 @@ class CFGDenoiser(nn.Module):
        self.invokeai_diffuser = InvokeAIDiffuserComponent(model,
                                                           model_forward_callback=lambda x, sigma, cond: self.inner_model(x, sigma, cond=cond))

    def prepare_to_sample(self, t_enc, **kwargs):
        extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
@ -106,12 +108,12 @@ class KSampler(Sampler):
        else:
            print(f'>> Ksampler using karras noise schedule (steps < {self.karras_max})')
            self.sigmas = self.karras_sigmas

    # ALERT: We are completely overriding the sample() method in the base class, which
    # means that inpainting will not work. To get this to work we need to be able to
    # modify the inner loop of k_heun, k_lms, etc, as is done in an ugly way
    # in the lstein/k-diffusion branch.

    @torch.no_grad()
    def decode(
        self,
@ -145,7 +147,7 @@ class KSampler(Sampler):
    @torch.no_grad()
    def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
        return x0

    # Most of these arguments are ignored and are only present for compatibility with
    # other samples
    @torch.no_grad()
@ -158,6 +160,7 @@ class KSampler(Sampler):
        callback=None,
        normals_sequence=None,
        img_callback=None,
+       attention_maps_callback=None,
        quantize_x0=False,
        eta=0.0,
        mask=None,
@ -171,7 +174,7 @@ class KSampler(Sampler):
        log_every_t=100,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
-       extra_conditioning_info=None,
+       extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo=None,
        threshold = 0,
        perlin = 0,
        # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
@ -204,6 +207,12 @@ class KSampler(Sampler):
        model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10))
        model_wrap_cfg.prepare_to_sample(S, extra_conditioning_info=extra_conditioning_info)

+       attention_map_token_ids = range(1, extra_conditioning_info.tokens_count_including_eos_bos - 1)
+       attention_maps_saver = None if attention_maps_callback is None else AttentionMapSaver(token_ids = attention_map_token_ids, latents_shape=x.shape[-2:])
+       if attention_maps_callback is not None:
+           model_wrap_cfg.invokeai_diffuser.setup_attention_map_saving(attention_maps_saver)

        extra_args = {
            'cond': conditioning,
            'uncond': unconditional_conditioning,
@ -217,6 +226,8 @@ class KSampler(Sampler):
            ),
            None,
        )
+       if attention_maps_callback is not None:
+           attention_maps_callback(attention_maps_saver)
        return sampling_result
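The callback itself only needs to accept the finished AttentionMapSaver; a minimal sketch of what a caller might pass in (the output path is illustrative):

def save_maps_callback(attention_maps_saver):
    # invoked once after sampling finishes, with the saver populated during the run
    if attention_maps_saver is not None:
        attention_maps_saver.write_maps_to_disk('/tmp/attention_maps.png')

A caller would hand this in via the new attention_maps_callback argument of KSampler.sample() alongside the usual conditioning arguments.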
    # this code will support inpainting if and when ksampler API modified or
@ -248,7 +259,7 @@ class KSampler(Sampler):
        # terrible, confusing names here
        steps = self.ddim_num_steps
        t_enc = self.t_enc

        # sigmas is a full steps in length, but t_enc might
        # be less. We start in the middle of the sigma array
        # and work our way to the end after t_enc steps.
@ -280,7 +291,7 @@ class KSampler(Sampler):
            return x_T + x
        else:
            return x

    def prepare_to_sample(self,t_enc,**kwargs):
        self.t_enc = t_enc
        self.model_wrap = None

View File

@ -5,8 +5,8 @@ from typing import Callable, Optional, Union
import torch

from ldm.models.diffusion.cross_attention_control import Arguments, \
-   remove_cross_attention_control, setup_cross_attention_control, Context
-   from ldm.modules.attention import get_mem_free_total
+   remove_cross_attention_control, setup_cross_attention_control, Context, get_cross_attention_modules, CrossAttentionType
+   from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver


class InvokeAIDiffuserComponent:
@ -21,7 +21,8 @@ class InvokeAIDiffuserComponent:
    class ExtraConditioningInfo:

-       def __init__(self, cross_attention_control_args: Optional[Arguments]):
+       def __init__(self, tokens_count_including_eos_bos:int, cross_attention_control_args: Optional[Arguments]):
+           self.tokens_count_including_eos_bos = tokens_count_including_eos_bos
            self.cross_attention_control_args = cross_attention_control_args

        @property
@ -52,7 +53,25 @@ class InvokeAIDiffuserComponent:
        self.cross_attention_control_context = None
        remove_cross_attention_control(self.model)
def setup_attention_map_saving(self, saver: AttentionMapSaver):
def callback(slice, dim, offset, slice_size, key):
if dim is not None:
# sliced tokens attention map saving is not implemented
return
saver.add_attention_maps(slice, key)
tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
for identifier, module in tokens_cross_attention_modules:
key = ('down' if identifier.startswith('down') else
'up' if identifier.startswith('up') else
'mid')
module.set_attention_slice_calculated_callback(
lambda slice, dim, offset, slice_size, key=key: callback(slice, dim, offset, slice_size, key))
def remove_attention_map_saving(self):
tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
for _, module in tokens_cross_attention_modules:
module.set_attention_slice_calculated_callback(None)
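The 'down'/'up'/'mid' key chosen in setup_attention_map_saving groups maps by UNet branch; a standalone check of that prefix logic (the module identifiers below are hypothetical examples, not names from the model):

for identifier in ['down_blocks.0.attn2', 'mid_block.attn2', 'up_blocks.2.attn2']:
    key = ('down' if identifier.startswith('down') else
           'up' if identifier.startswith('up') else
           'mid')
    print(identifier, '->', key)   # down_blocks... -> down, mid_block... -> mid, up_blocks... -> up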
    def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor,
                          unconditioning: Union[torch.Tensor,dict],

View File

@ -7,10 +7,9 @@ import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat

+ from ldm.models.diffusion.cross_attention_control import InvokeAICrossAttentionMixin
from ldm.modules.diffusionmodules.util import checkpoint
- import psutil


def exists(val):
    return val is not None
@ -164,9 +163,10 @@ def get_mem_free_total(device):
    return mem_free_total


- class CrossAttention(nn.Module):
+ class CrossAttention(nn.Module, InvokeAICrossAttentionMixin):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
        super().__init__()
+       InvokeAICrossAttentionMixin.__init__(self)
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)
@ -182,118 +182,6 @@ class CrossAttention(nn.Module):
            nn.Dropout(dropout)
        )
self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
self.cached_mem_free_total = None
self.attention_slice_wrangler = None
self.slicing_strategy_getter = None
def set_attention_slice_wrangler(self, wrangler: Optional[Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]]):
'''
Set custom attention calculator to be called when attention is calculated
:param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size),
which returns either the suggested_attention_slice or an adjusted equivalent.
`module` is the current CrossAttention module for which the callback is being invoked.
`suggested_attention_slice` is the default-calculated attention slice
`dim` is -1 if the attenion map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length.
Pass None to use the default attention calculation.
:return:
'''
self.attention_slice_wrangler = wrangler
def set_slicing_strategy_getter(self, getter: Optional[Callable[[nn.Module], tuple[int,int]]]):
self.slicing_strategy_getter = getter
def cache_free_memory_count(self, device):
self.cached_mem_free_total = get_mem_free_total(device)
print("free cuda memory: ", self.cached_mem_free_total)
def clear_cached_free_memory_count(self):
self.cached_mem_free_total = None
def einsum_lowest_level(self, q, k, v, dim, offset, slice_size):
# calculate attention scores
attention_scores = einsum('b i d, b j d -> b i j', q, k)
# calculate attention slice by taking the best scores for each latent pixel
default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
attention_slice_wrangler = self.attention_slice_wrangler
if attention_slice_wrangler is not None:
attention_slice = attention_slice_wrangler(self, default_attention_slice, dim, offset, slice_size)
else:
attention_slice = default_attention_slice
return einsum('b i j, b j d -> b i d', attention_slice, v)
def einsum_op_slice_dim0(self, q, k, v, slice_size):
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
for i in range(0, q.shape[0], slice_size):
end = i + slice_size
r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size)
return r
def einsum_op_slice_dim1(self, q, k, v, slice_size):
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size)
return r
def einsum_op_mps_v1(self, q, k, v):
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
return self.einsum_lowest_level(q, k, v, None, None, None)
else:
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
return self.einsum_op_slice_dim1(q, k, v, slice_size)
def einsum_op_mps_v2(self, q, k, v):
if self.mem_total_gb > 8 and q.shape[1] <= 4096:
return self.einsum_lowest_level(q, k, v, None, None, None)
else:
return self.einsum_op_slice_dim0(q, k, v, 1)
def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb):
size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
if size_mb <= max_tensor_mb:
return self.einsum_lowest_level(q, k, v, None, None, None)
div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
if div <= q.shape[0]:
return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div)
return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1))
def einsum_op_cuda(self, q, k, v):
# check if we already have a slicing strategy (this should only happen during cross-attention controlled generation)
slicing_strategy_getter = self.slicing_strategy_getter
if slicing_strategy_getter is not None:
(dim, slice_size) = slicing_strategy_getter(self)
if dim is not None:
# print("using saved slicing strategy with dim", dim, "slice size", slice_size)
if dim == 0:
return self.einsum_op_slice_dim0(q, k, v, slice_size)
elif dim == 1:
return self.einsum_op_slice_dim1(q, k, v, slice_size)
# fallback for when there is no saved strategy, or saved strategy does not slice
mem_free_total = self.cached_mem_free_total or get_mem_free_total(q.device)
# Divide factor of safety as there's copying and fragmentation
return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
def get_attention_mem_efficient(self, q, k, v):
if q.device.type == 'cuda':
#print("in get_attention_mem_efficient with q shape", q.shape, ", k shape", k.shape, ", free memory is", get_mem_free_total(q.device))
return self.einsum_op_cuda(q, k, v)
if q.device.type == 'mps':
if self.mem_total_gb >= 32:
return self.einsum_op_mps_v1(q, k, v)
return self.einsum_op_mps_v2(q, k, v)
# Smaller slices are faster due to L2/L3/SLC caches.
# Tested on i7 with 8MB L3 cache.
return self.einsum_op_tensor_mem(q, k, v, 32)
    def forward(self, x, context=None, mask=None):
        h = self.heads
@ -305,7 +193,11 @@ class CrossAttention(nn.Module):
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

-       r = self.get_attention_mem_efficient(q, k, v)
+       # don't apply scale twice
+       cached_scale = self.scale
+       self.scale = 1
+       r = self.get_invokeai_attention_mem_efficient(q, k, v)
+       self.scale = cached_scale

        hidden_states = rearrange(r, '(b h) n d -> b n (h d)', h=h)
        return self.to_out(hidden_states)
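The cached_scale juggling exists because the mixin's baddbmm path already multiplies the scores by alpha=self.scale; if the factor has also been applied to q earlier in forward(), it would otherwise be applied twice. A standalone numeric check (the shapes and the assumption about a pre-scaled q are illustrative):

import torch

q = torch.randn(2, 16, 8)
k = torch.randn(2, 16, 8)
scale = 8 ** -0.5

def scores(q, k, alpha):
    # same form of call as the mixin's einsum_lowest_level
    return torch.baddbmm(torch.empty(2, 16, 16), q, k.transpose(-1, -2), beta=0, alpha=alpha)

reference = scores(q, k, scale)                                           # scale applied exactly once
print(torch.allclose(reference, scores(q * scale, k, 1.0), atol=1e-6))    # True: pre-scaled q, alpha forced to 1
print(torch.allclose(reference, scores(q * scale, k, scale), atol=1e-6))  # False: scale applied twice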