mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Save and display per-token attention maps (#1866)
* attention maps saving to /tmp * tidy up diffusers branch backporting of cross attention refactoring * base64-encoding the attention maps image for generationResult * cleanup/refactor conditioning.py * attention maps and tokens being sent to web UI * attention maps: restrict count to actual token count and improve robustness * add argument type hint to image_to_dataURL function Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Co-authored-by: damian <git@damianstewart.com> Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
This commit is contained in:
parent
55132f6463
commit
786b8878d6
@ -18,9 +18,11 @@ from PIL.Image import Image as ImageType
|
||||
from uuid import uuid4
|
||||
from threading import Event
|
||||
|
||||
from ldm.generate import Generate
|
||||
from ldm.invoke.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash
|
||||
from ldm.invoke.conditioning import get_tokens_for_prompt, get_prompt_structure
|
||||
from ldm.invoke.pngwriter import PngWriter, retrieve_metadata
|
||||
from ldm.invoke.prompt_parser import split_weighted_subprompts
|
||||
from ldm.invoke.prompt_parser import split_weighted_subprompts, Blend
|
||||
from ldm.invoke.generator.inpaint import infill_methods
|
||||
|
||||
from backend.modules.parameters import parameters_to_command
|
||||
@ -39,7 +41,7 @@ if not os.path.isabs(args.outdir):
|
||||
|
||||
|
||||
class InvokeAIWebServer:
|
||||
def __init__(self, generate, gfpgan, codeformer, esrgan) -> None:
|
||||
def __init__(self, generate: Generate, gfpgan, codeformer, esrgan) -> None:
|
||||
self.host = args.host
|
||||
self.port = args.port
|
||||
|
||||
@ -905,16 +907,13 @@ class InvokeAIWebServer:
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
if generation_parameters["progress_latents"]:
|
||||
image = self.generate.sample_to_lowres_estimated_image(sample)
|
||||
(width, height) = image.size
|
||||
width *= 8
|
||||
height *= 8
|
||||
buffered = io.BytesIO()
|
||||
image.save(buffered, format="PNG")
|
||||
img_base64 = "data:image/png;base64," + base64.b64encode(
|
||||
buffered.getvalue()
|
||||
).decode("UTF-8")
|
||||
img_base64 = image_to_dataURL(image)
|
||||
self.socketio.emit(
|
||||
"intermediateResult",
|
||||
{
|
||||
@ -932,7 +931,7 @@ class InvokeAIWebServer:
|
||||
self.socketio.emit("progressUpdate", progress.to_formatted_dict())
|
||||
eventlet.sleep(0)
|
||||
|
||||
def image_done(image, seed, first_seed):
|
||||
def image_done(image, seed, first_seed, attention_maps_image=None):
|
||||
if self.canceled.is_set():
|
||||
raise CanceledException
|
||||
|
||||
@ -1094,6 +1093,12 @@ class InvokeAIWebServer:
|
||||
self.socketio.emit("progressUpdate", progress.to_formatted_dict())
|
||||
eventlet.sleep(0)
|
||||
|
||||
parsed_prompt, _ = get_prompt_structure(generation_parameters["prompt"])
|
||||
tokens = None if type(parsed_prompt) is Blend else \
|
||||
get_tokens_for_prompt(self.generate.model, parsed_prompt)
|
||||
attention_maps_image_base64_url = None if attention_maps_image is None \
|
||||
else image_to_dataURL(attention_maps_image)
|
||||
|
||||
self.socketio.emit(
|
||||
"generationResult",
|
||||
{
|
||||
@ -1106,6 +1111,8 @@ class InvokeAIWebServer:
|
||||
"height": height,
|
||||
"boundingBox": original_bounding_box,
|
||||
"generationMode": generation_parameters["generation_mode"],
|
||||
"attentionMaps": attention_maps_image_base64_url,
|
||||
"tokens": tokens,
|
||||
},
|
||||
)
|
||||
eventlet.sleep(0)
|
||||
@ -1117,7 +1124,7 @@ class InvokeAIWebServer:
|
||||
self.generate.prompt2image(
|
||||
**generation_parameters,
|
||||
step_callback=image_progress,
|
||||
image_callback=image_done,
|
||||
image_callback=image_done
|
||||
)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
@ -1564,6 +1571,19 @@ def dataURL_to_image(dataURL: str) -> ImageType:
|
||||
)
|
||||
return image
|
||||
|
||||
"""
|
||||
Converts an image into a base64 image dataURL.
|
||||
"""
|
||||
|
||||
def image_to_dataURL(image: ImageType) -> str:
|
||||
buffered = io.BytesIO()
|
||||
image.save(buffered, format="PNG")
|
||||
image_base64 = "data:image/png;base64," + base64.b64encode(
|
||||
buffered.getvalue()
|
||||
).decode("UTF-8")
|
||||
return image_base64
|
||||
|
||||
|
||||
|
||||
"""
|
||||
Converts a base64 image dataURL into bytes.
|
||||
|
@ -20,6 +20,8 @@ import cv2
|
||||
import skimage
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
import ldm.invoke.conditioning
|
||||
from ldm.invoke.generator.base import downsampling
|
||||
from PIL import Image, ImageOps
|
||||
from torch import nn
|
||||
@ -40,7 +42,7 @@ from ldm.invoke.model_cache import ModelCache
|
||||
from ldm.invoke.seamless import configure_model_padding
|
||||
from ldm.invoke.txt2mask import Txt2Mask, SegmentedGrayscale
|
||||
from ldm.invoke.concepts_lib import Concepts
|
||||
|
||||
|
||||
def fix_func(orig):
|
||||
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
||||
def new_func(*args, **kw):
|
||||
@ -235,7 +237,7 @@ class Generate:
|
||||
except Exception:
|
||||
print('** An error was encountered while installing the safety checker:')
|
||||
print(traceback.format_exc())
|
||||
|
||||
|
||||
def prompt2png(self, prompt, outdir, **kwargs):
|
||||
"""
|
||||
Takes a prompt and an output directory, writes out the requested number
|
||||
@ -329,7 +331,7 @@ class Generate:
|
||||
infill_method = infill_methods[0], # The infill method to use
|
||||
force_outpaint: bool = False,
|
||||
enable_image_debugging = False,
|
||||
|
||||
|
||||
**args,
|
||||
): # eat up additional cruft
|
||||
"""
|
||||
@ -372,7 +374,7 @@ class Generate:
|
||||
def process_image(image,seed):
|
||||
image.save(f{'images/seed.png'})
|
||||
|
||||
The code used to save images to a directory can be found in ldm/invoke/pngwriter.py.
|
||||
The code used to save images to a directory can be found in ldm/invoke/pngwriter.py.
|
||||
It contains code to create the requested output directory, select a unique informative
|
||||
name for each image, and write the prompt into the PNG metadata.
|
||||
"""
|
||||
@ -455,7 +457,7 @@ class Generate:
|
||||
try:
|
||||
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(
|
||||
prompt, model =self.model,
|
||||
skip_normalize=skip_normalize,
|
||||
skip_normalize_legacy_blend=skip_normalize,
|
||||
log_tokens =self.log_tokenization
|
||||
)
|
||||
|
||||
@ -589,7 +591,7 @@ class Generate:
|
||||
seed = opt.seed or args.seed
|
||||
if seed is None or seed < 0:
|
||||
seed = random.randrange(0, np.iinfo(np.uint32).max)
|
||||
|
||||
|
||||
prompt = opt.prompt or args.prompt or ''
|
||||
print(f'>> using seed {seed} and prompt "{prompt}" for {image_path}')
|
||||
|
||||
@ -607,8 +609,8 @@ class Generate:
|
||||
# todo: cross-attention control
|
||||
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(
|
||||
prompt, model =self.model,
|
||||
skip_normalize=opt.skip_normalize,
|
||||
log_tokens =opt.log_tokenization
|
||||
skip_normalize_legacy_blend=opt.skip_normalize,
|
||||
log_tokens =ldm.invoke.conditioning.log_tokenization
|
||||
)
|
||||
|
||||
if tool in ('gfpgan','codeformer','upscale'):
|
||||
@ -641,7 +643,7 @@ class Generate:
|
||||
|
||||
opt.seed = seed
|
||||
opt.prompt = prompt
|
||||
|
||||
|
||||
if len(extend_instructions) > 0:
|
||||
restorer = Outcrop(image,self,)
|
||||
return restorer.process (
|
||||
@ -683,7 +685,7 @@ class Generate:
|
||||
image_callback = callback,
|
||||
prefix = prefix
|
||||
)
|
||||
|
||||
|
||||
elif tool is None:
|
||||
print(f'* please provide at least one postprocessing option, such as -G or -U')
|
||||
return None
|
||||
@ -706,13 +708,13 @@ class Generate:
|
||||
|
||||
if embiggen is not None:
|
||||
return self._make_embiggen()
|
||||
|
||||
|
||||
if inpainting_model_in_use:
|
||||
return self._make_omnibus()
|
||||
|
||||
if ((init_image is not None) and (mask_image is not None)) or force_outpaint:
|
||||
return self._make_inpaint()
|
||||
|
||||
|
||||
if init_image is not None:
|
||||
return self._make_img2img()
|
||||
|
||||
@ -743,7 +745,7 @@ class Generate:
|
||||
if self._has_transparency(image):
|
||||
self._transparency_check_and_warning(image, mask, force_outpaint)
|
||||
init_mask = self._create_init_mask(image, width, height, fit=fit)
|
||||
|
||||
|
||||
if (image.width * image.height) > (self.width * self.height) and self.size_matters:
|
||||
print(">> This input is larger than your defaults. If you run out of memory, please use a smaller image.")
|
||||
self.size_matters = False
|
||||
@ -759,7 +761,7 @@ class Generate:
|
||||
|
||||
if init_mask and invert_mask:
|
||||
init_mask = ImageOps.invert(init_mask)
|
||||
|
||||
|
||||
return init_image,init_mask
|
||||
|
||||
# lots o' repeated code here! Turn into a make_func()
|
||||
@ -818,7 +820,7 @@ class Generate:
|
||||
self.set_model(self.model_name)
|
||||
|
||||
def set_model(self,model_name):
|
||||
"""
|
||||
"""
|
||||
Given the name of a model defined in models.yaml, will load and initialize it
|
||||
and return the model object. Previously-used models will be cached.
|
||||
"""
|
||||
@ -830,7 +832,7 @@ class Generate:
|
||||
if not cache.valid_model(model_name):
|
||||
print(f'** "{model_name}" is not a known model name. Please check your models.yaml file')
|
||||
return self.model
|
||||
|
||||
|
||||
cache.print_vram_usage()
|
||||
|
||||
# have to get rid of all references to model in order
|
||||
@ -839,7 +841,7 @@ class Generate:
|
||||
self.sampler = None
|
||||
self.generators = {}
|
||||
gc.collect()
|
||||
|
||||
|
||||
model_data = cache.get_model(model_name)
|
||||
if model_data is None: # restore previous
|
||||
model_data = cache.get_model(self.model_name)
|
||||
@ -852,7 +854,7 @@ class Generate:
|
||||
|
||||
# uncache generators so they pick up new models
|
||||
self.generators = {}
|
||||
|
||||
|
||||
seed_everything(random.randrange(0, np.iinfo(np.uint32).max))
|
||||
if self.embedding_path is not None:
|
||||
self.model.embedding_manager.load(
|
||||
@ -901,7 +903,7 @@ class Generate:
|
||||
image_callback = None,
|
||||
prefix = None,
|
||||
):
|
||||
|
||||
|
||||
for r in image_list:
|
||||
image, seed = r
|
||||
try:
|
||||
@ -911,7 +913,7 @@ class Generate:
|
||||
if self.gfpgan is None:
|
||||
print('>> GFPGAN not found. Face restoration is disabled.')
|
||||
else:
|
||||
image = self.gfpgan.process(image, strength, seed)
|
||||
image = self.gfpgan.process(image, strength, seed)
|
||||
if facetool == 'codeformer':
|
||||
if self.codeformer is None:
|
||||
print('>> CodeFormer not found. Face restoration is disabled.')
|
||||
|
@ -8,6 +8,7 @@ import time
|
||||
import traceback
|
||||
import yaml
|
||||
|
||||
from ldm.generate import Generate
|
||||
from ldm.invoke.globals import Globals
|
||||
from ldm.invoke.prompt_parser import PromptParser
|
||||
from ldm.invoke.readline import get_completer, Completer
|
||||
@ -27,7 +28,7 @@ def main():
|
||||
"""Initialize command-line parsers and the diffusion model"""
|
||||
global infile
|
||||
print('* Initializing, be patient...')
|
||||
|
||||
|
||||
opt = Args()
|
||||
args = opt.parse_args()
|
||||
if not args:
|
||||
@ -47,7 +48,7 @@ def main():
|
||||
# alert - setting globals here
|
||||
Globals.root = os.path.expanduser(args.root_dir or os.environ.get('INVOKEAI_ROOT') or os.path.abspath('.'))
|
||||
Globals.try_patchmatch = args.patchmatch
|
||||
|
||||
|
||||
print(f'>> InvokeAI runtime directory is "{Globals.root}"')
|
||||
|
||||
# loading here to avoid long delays on startup
|
||||
@ -281,7 +282,7 @@ def main_loop(gen, opt):
|
||||
prefix = file_writer.unique_prefix()
|
||||
step_callback = make_step_callback(gen, opt, prefix) if opt.save_intermediates > 0 else None
|
||||
|
||||
def image_writer(image, seed, upscaled=False, first_seed=None, use_prefix=None, prompt_in=None):
|
||||
def image_writer(image, seed, upscaled=False, first_seed=None, use_prefix=None, prompt_in=None, attention_maps_image=None):
|
||||
# note the seed is the seed of the current image
|
||||
# the first_seed is the original seed that noise is added to
|
||||
# when the -v switch is used to generate variations
|
||||
@ -341,8 +342,8 @@ def main_loop(gen, opt):
|
||||
filename,
|
||||
tool,
|
||||
formatted_dream_prompt,
|
||||
)
|
||||
|
||||
)
|
||||
|
||||
if (not postprocessed) or opt.save_original:
|
||||
# only append to results if we didn't overwrite an earlier output
|
||||
results.append([path, formatted_dream_prompt])
|
||||
@ -432,7 +433,7 @@ def do_command(command:str, gen, opt:Args, completer) -> tuple:
|
||||
add_embedding_terms(gen, completer)
|
||||
completer.add_history(command)
|
||||
operation = None
|
||||
|
||||
|
||||
elif command.startswith('!models'):
|
||||
gen.model_cache.print_models()
|
||||
completer.add_history(command)
|
||||
@ -533,7 +534,7 @@ def add_weights_to_config(model_path:str, gen, opt, completer):
|
||||
|
||||
completer.complete_extensions(('.yaml','.yml'))
|
||||
completer.linebuffer = 'configs/stable-diffusion/v1-inference.yaml'
|
||||
|
||||
|
||||
done = False
|
||||
while not done:
|
||||
new_config['config'] = input('Configuration file for this model: ')
|
||||
@ -564,7 +565,7 @@ def add_weights_to_config(model_path:str, gen, opt, completer):
|
||||
print('** Please enter a valid integer between 64 and 2048')
|
||||
|
||||
make_default = input('Make this the default model? [n] ') in ('y','Y')
|
||||
|
||||
|
||||
if write_config_file(opt.conf, gen, model_name, new_config, make_default=make_default):
|
||||
completer.add_model(model_name)
|
||||
|
||||
@ -577,14 +578,14 @@ def del_config(model_name:str, gen, opt, completer):
|
||||
gen.model_cache.commit(opt.conf)
|
||||
print(f'** {model_name} deleted')
|
||||
completer.del_model(model_name)
|
||||
|
||||
|
||||
def edit_config(model_name:str, gen, opt, completer):
|
||||
config = gen.model_cache.config
|
||||
|
||||
|
||||
if model_name not in config:
|
||||
print(f'** Unknown model {model_name}')
|
||||
return
|
||||
|
||||
|
||||
print(f'\n>> Editing model {model_name} from configuration file {opt.conf}')
|
||||
|
||||
conf = config[model_name]
|
||||
@ -597,10 +598,10 @@ def edit_config(model_name:str, gen, opt, completer):
|
||||
make_default = input('Make this the default model? [n] ') in ('y','Y')
|
||||
completer.complete_extensions(None)
|
||||
write_config_file(opt.conf, gen, model_name, new_config, clobber=True, make_default=make_default)
|
||||
|
||||
|
||||
def write_config_file(conf_path, gen, model_name, new_config, clobber=False, make_default=False):
|
||||
current_model = gen.model_name
|
||||
|
||||
|
||||
op = 'modify' if clobber else 'import'
|
||||
print('\n>> New configuration:')
|
||||
if make_default:
|
||||
@ -623,7 +624,7 @@ def write_config_file(conf_path, gen, model_name, new_config, clobber=False, mak
|
||||
gen.model_cache.set_default_model(model_name)
|
||||
|
||||
gen.model_cache.commit(conf_path)
|
||||
|
||||
|
||||
do_switch = input(f'Keep model loaded? [y]')
|
||||
if len(do_switch)==0 or do_switch[0] in ('y','Y'):
|
||||
pass
|
||||
@ -653,7 +654,7 @@ def do_postprocess (gen, opt, callback):
|
||||
opt.prompt = opt.new_prompt
|
||||
else:
|
||||
opt.prompt = None
|
||||
|
||||
|
||||
if os.path.dirname(file_path) == '': #basename given
|
||||
file_path = os.path.join(opt.outdir,file_path)
|
||||
|
||||
@ -718,7 +719,7 @@ def add_postprocessing_to_metadata(opt,original_file,new_file,tool,command):
|
||||
)
|
||||
meta['image']['postprocessing'] = pp
|
||||
write_metadata(new_file,meta)
|
||||
|
||||
|
||||
def prepare_image_metadata(
|
||||
opt,
|
||||
prefix,
|
||||
@ -789,28 +790,28 @@ def get_next_command(infile=None) -> str: # command string
|
||||
print(f'#{command}')
|
||||
return command
|
||||
|
||||
def invoke_ai_web_server_loop(gen, gfpgan, codeformer, esrgan):
|
||||
def invoke_ai_web_server_loop(gen: Generate, gfpgan, codeformer, esrgan):
|
||||
print('\n* --web was specified, starting web server...')
|
||||
from backend.invoke_ai_web_server import InvokeAIWebServer
|
||||
# Change working directory to the stable-diffusion directory
|
||||
os.chdir(
|
||||
os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
|
||||
)
|
||||
|
||||
|
||||
invoke_ai_web_server = InvokeAIWebServer(generate=gen, gfpgan=gfpgan, codeformer=codeformer, esrgan=esrgan)
|
||||
|
||||
try:
|
||||
invoke_ai_web_server.run()
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
|
||||
def add_embedding_terms(gen,completer):
|
||||
'''
|
||||
Called after setting the model, updates the autocompleter with
|
||||
any terms loaded by the embedding manager.
|
||||
'''
|
||||
completer.add_embedding_terms(gen.model.embedding_manager.list_terms())
|
||||
|
||||
|
||||
def split_variations(variations_string) -> list:
|
||||
# shotgun parsing, woo
|
||||
parts = []
|
||||
@ -867,7 +868,7 @@ def make_step_callback(gen, opt, prefix):
|
||||
image = gen.sample_to_image(img)
|
||||
image.save(filename,'PNG')
|
||||
return callback
|
||||
|
||||
|
||||
def retrieve_dream_command(opt,command,completer):
|
||||
'''
|
||||
Given a full or partial path to a previously-generated image file,
|
||||
@ -875,7 +876,7 @@ def retrieve_dream_command(opt,command,completer):
|
||||
and pop it into the readline buffer (linux, Mac), or print out a comment
|
||||
for cut-and-paste (windows)
|
||||
|
||||
Given a wildcard path to a folder with image png files,
|
||||
Given a wildcard path to a folder with image png files,
|
||||
will retrieve and format the dream command used to generate the images,
|
||||
and save them to a file commands.txt for further processing
|
||||
'''
|
||||
@ -911,7 +912,7 @@ def write_commands(opt, file_path:str, outfilepath:str):
|
||||
except ValueError:
|
||||
print(f'## "{basename}": unacceptable pattern')
|
||||
return
|
||||
|
||||
|
||||
commands = []
|
||||
cmd = None
|
||||
for path in paths:
|
||||
@ -940,7 +941,7 @@ def emergency_model_reconfigure():
|
||||
print(' After reconfiguration is done, please relaunch invoke.py. ')
|
||||
print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
|
||||
print('configure_invokeai is launching....\n')
|
||||
|
||||
|
||||
sys.argv = ['configure_invokeai','--interactive']
|
||||
import configure_invokeai
|
||||
configure_invokeai.main()
|
||||
|
@ -7,20 +7,46 @@ get_uc_and_c_and_ec() get the conditioned and unconditioned latent, an
|
||||
|
||||
'''
|
||||
import re
|
||||
from difflib import SequenceMatcher
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
|
||||
from .prompt_parser import PromptParser, Blend, FlattenedPrompt, \
|
||||
CrossAttentionControlledFragment, CrossAttentionControlSubstitute, Fragment, log_tokenization
|
||||
CrossAttentionControlledFragment, CrossAttentionControlSubstitute, Fragment
|
||||
from ..models.diffusion import cross_attention_control
|
||||
from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder
|
||||
|
||||
|
||||
def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_normalize=False):
|
||||
def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False):
|
||||
prompt, negative_prompt = get_prompt_structure(prompt_string,
|
||||
skip_normalize_legacy_blend=skip_normalize_legacy_blend)
|
||||
conditioning = _get_conditioning_for_prompt(prompt, negative_prompt, model, log_tokens)
|
||||
|
||||
return conditioning
|
||||
|
||||
|
||||
def get_prompt_structure(prompt_string, skip_normalize_legacy_blend: bool = False) -> (
|
||||
Union[FlattenedPrompt, Blend], FlattenedPrompt):
|
||||
"""
|
||||
parse the passed-in prompt string and return tuple (positive_prompt, negative_prompt)
|
||||
"""
|
||||
prompt, negative_prompt = _parse_prompt_string(prompt_string,
|
||||
skip_normalize_legacy_blend=skip_normalize_legacy_blend)
|
||||
return prompt, negative_prompt
|
||||
|
||||
|
||||
def get_tokens_for_prompt(model, parsed_prompt: FlattenedPrompt) -> [str]:
|
||||
text_fragments = [x.text if type(x) is Fragment else
|
||||
(" ".join([f.text for f in x.original]) if type(x) is CrossAttentionControlSubstitute else
|
||||
str(x))
|
||||
for x in parsed_prompt.children]
|
||||
text = " ".join(text_fragments)
|
||||
tokens = model.cond_stage_model.tokenizer.tokenize(text)
|
||||
return tokens
|
||||
|
||||
|
||||
def _parse_prompt_string(prompt_string_uncleaned, skip_normalize_legacy_blend=False) -> Union[FlattenedPrompt, Blend]:
|
||||
# Extract Unconditioned Words From Prompt
|
||||
unconditioned_words = ''
|
||||
unconditional_regex = r'\[(.*?)\]'
|
||||
@ -39,7 +65,7 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
|
||||
pp = PromptParser()
|
||||
|
||||
parsed_prompt: Union[FlattenedPrompt, Blend] = None
|
||||
legacy_blend: Blend = pp.parse_legacy_blend(prompt_string_cleaned)
|
||||
legacy_blend: Blend = pp.parse_legacy_blend(prompt_string_cleaned, skip_normalize_legacy_blend)
|
||||
if legacy_blend is not None:
|
||||
parsed_prompt = legacy_blend
|
||||
else:
|
||||
@ -47,129 +73,150 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
|
||||
parsed_prompt = pp.parse_conjunction(prompt_string_cleaned).prompts[0]
|
||||
|
||||
parsed_negative_prompt: FlattenedPrompt = pp.parse_conjunction(unconditioned_words).prompts[0]
|
||||
return parsed_prompt, parsed_negative_prompt
|
||||
|
||||
|
||||
def _get_conditioning_for_prompt(parsed_prompt: Union[Blend, FlattenedPrompt], parsed_negative_prompt: FlattenedPrompt,
|
||||
model, log_tokens=False) \
|
||||
-> tuple[torch.Tensor, torch.Tensor, InvokeAIDiffuserComponent.ExtraConditioningInfo]:
|
||||
"""
|
||||
Process prompt structure and tokens, and return (conditioning, unconditioning, extra_conditioning_info)
|
||||
"""
|
||||
|
||||
if log_tokens:
|
||||
print(f">> Parsed prompt to {parsed_prompt}")
|
||||
print(f">> Parsed negative prompt to {parsed_negative_prompt}")
|
||||
|
||||
conditioning = None
|
||||
cac_args:cross_attention_control.Arguments = None
|
||||
cac_args: cross_attention_control.Arguments = None
|
||||
|
||||
if type(parsed_prompt) is Blend:
|
||||
blend: Blend = parsed_prompt
|
||||
embeddings_to_blend = None
|
||||
for i,flattened_prompt in enumerate(blend.prompts):
|
||||
this_embedding, _ = build_embeddings_and_tokens_for_flattened_prompt(model,
|
||||
flattened_prompt,
|
||||
log_tokens=log_tokens,
|
||||
log_display_label=f"(blend part {i+1}, weight={blend.weights[i]})" )
|
||||
embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat(
|
||||
(embeddings_to_blend, this_embedding))
|
||||
conditioning = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0),
|
||||
blend.weights,
|
||||
normalize=blend.normalize_weights)
|
||||
else:
|
||||
flattened_prompt: FlattenedPrompt = parsed_prompt
|
||||
wants_cross_attention_control = type(flattened_prompt) is not Blend \
|
||||
and any([issubclass(type(x), CrossAttentionControlledFragment) for x in flattened_prompt.children])
|
||||
if wants_cross_attention_control:
|
||||
original_prompt = FlattenedPrompt()
|
||||
edited_prompt = FlattenedPrompt()
|
||||
# for name, a0, a1, b0, b1 in edit_opcodes: only name == 'equal' is currently parsed
|
||||
original_token_count = 0
|
||||
edited_token_count = 0
|
||||
edit_options = []
|
||||
edit_opcodes = []
|
||||
# beginning of sequence
|
||||
edit_opcodes.append(('equal', original_token_count, original_token_count+1, edited_token_count, edited_token_count+1))
|
||||
edit_options.append(None)
|
||||
original_token_count += 1
|
||||
edited_token_count += 1
|
||||
for fragment in flattened_prompt.children:
|
||||
if type(fragment) is CrossAttentionControlSubstitute:
|
||||
original_prompt.append(fragment.original)
|
||||
edited_prompt.append(fragment.edited)
|
||||
conditioning = _get_conditioning_for_blend(model, parsed_prompt, log_tokens)
|
||||
elif type(parsed_prompt) is FlattenedPrompt:
|
||||
if parsed_prompt.wants_cross_attention_control:
|
||||
conditioning, cac_args = _get_conditioning_for_cross_attention_control(model, parsed_prompt, log_tokens)
|
||||
|
||||
to_replace_token_count = get_tokens_length(model, fragment.original)
|
||||
replacement_token_count = get_tokens_length(model, fragment.edited)
|
||||
edit_opcodes.append(('replace',
|
||||
original_token_count, original_token_count + to_replace_token_count,
|
||||
edited_token_count, edited_token_count + replacement_token_count
|
||||
))
|
||||
original_token_count += to_replace_token_count
|
||||
edited_token_count += replacement_token_count
|
||||
edit_options.append(fragment.options)
|
||||
#elif type(fragment) is CrossAttentionControlAppend:
|
||||
# edited_prompt.append(fragment.fragment)
|
||||
else:
|
||||
# regular fragment
|
||||
original_prompt.append(fragment)
|
||||
edited_prompt.append(fragment)
|
||||
|
||||
count = get_tokens_length(model, [fragment])
|
||||
edit_opcodes.append(('equal', original_token_count, original_token_count+count, edited_token_count, edited_token_count+count))
|
||||
edit_options.append(None)
|
||||
original_token_count += count
|
||||
edited_token_count += count
|
||||
# end of sequence
|
||||
edit_opcodes.append(('equal', original_token_count, original_token_count+1, edited_token_count, edited_token_count+1))
|
||||
edit_options.append(None)
|
||||
original_token_count += 1
|
||||
edited_token_count += 1
|
||||
|
||||
original_embeddings, original_tokens = build_embeddings_and_tokens_for_flattened_prompt(model,
|
||||
original_prompt,
|
||||
log_tokens=log_tokens,
|
||||
log_display_label="(.swap originals)")
|
||||
# naïvely building a single edited_embeddings like this disregards the effects of changing the absolute location of
|
||||
# subsequent tokens when there is >1 edit and earlier edits change the total token count.
|
||||
# eg "a cat.swap(smiling dog, s_start=0.5) eating a hotdog.swap(pizza)" - when the 'pizza' edit is active but the
|
||||
# 'cat' edit is not, the 'pizza' feature vector will nevertheless be affected by the introduction of the extra
|
||||
# token 'smiling' in the inactive 'cat' edit.
|
||||
# todo: build multiple edited_embeddings, one for each edit, and pass just the edited fragments through to the CrossAttentionControl functions
|
||||
edited_embeddings, edited_tokens = build_embeddings_and_tokens_for_flattened_prompt(model,
|
||||
edited_prompt,
|
||||
log_tokens=log_tokens,
|
||||
log_display_label="(.swap replacements)")
|
||||
|
||||
conditioning = original_embeddings
|
||||
edited_conditioning = edited_embeddings
|
||||
#print('>> got edit_opcodes', edit_opcodes, 'options', edit_options)
|
||||
cac_args = cross_attention_control.Arguments(
|
||||
edited_conditioning = edited_conditioning,
|
||||
edit_opcodes = edit_opcodes,
|
||||
edit_options = edit_options
|
||||
)
|
||||
else:
|
||||
conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model,
|
||||
flattened_prompt,
|
||||
log_tokens=log_tokens,
|
||||
log_display_label="(prompt)")
|
||||
conditioning, _ = _get_embeddings_and_tokens_for_prompt(model,
|
||||
parsed_prompt,
|
||||
log_tokens=log_tokens,
|
||||
log_display_label="(prompt)")
|
||||
else:
|
||||
raise ValueError(f"parsed_prompt is '{type(parsed_prompt)}' which is not a supported prompt type")
|
||||
|
||||
unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model,
|
||||
parsed_negative_prompt,
|
||||
log_tokens=log_tokens,
|
||||
log_display_label="(unconditioning)")
|
||||
unconditioning, _ = _get_embeddings_and_tokens_for_prompt(model,
|
||||
parsed_negative_prompt,
|
||||
log_tokens=log_tokens,
|
||||
log_display_label="(unconditioning)")
|
||||
if isinstance(conditioning, dict):
|
||||
# hybrid conditioning is in play
|
||||
unconditioning, conditioning = flatten_hybrid_conditioning(unconditioning, conditioning)
|
||||
unconditioning, conditioning = _flatten_hybrid_conditioning(unconditioning, conditioning)
|
||||
if cac_args is not None:
|
||||
print(">> Hybrid conditioning cannot currently be combined with cross attention control. Cross attention control will be ignored.")
|
||||
print(
|
||||
">> Hybrid conditioning cannot currently be combined with cross attention control. Cross attention control will be ignored.")
|
||||
cac_args = None
|
||||
|
||||
eos_token_index = 1
|
||||
if type(parsed_prompt) is not Blend:
|
||||
tokens = get_tokens_for_prompt(model, parsed_prompt)
|
||||
eos_token_index = len(tokens)+1
|
||||
return (
|
||||
unconditioning, conditioning, InvokeAIDiffuserComponent.ExtraConditioningInfo(
|
||||
tokens_count_including_eos_bos=eos_token_index + 1,
|
||||
cross_attention_control_args=cac_args
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def build_token_edit_opcodes(original_tokens, edited_tokens):
|
||||
original_tokens = original_tokens.cpu().numpy()[0]
|
||||
edited_tokens = edited_tokens.cpu().numpy()[0]
|
||||
def _get_conditioning_for_cross_attention_control(model, prompt: FlattenedPrompt, log_tokens: bool = True):
|
||||
original_prompt = FlattenedPrompt()
|
||||
edited_prompt = FlattenedPrompt()
|
||||
# for name, a0, a1, b0, b1 in edit_opcodes: only name == 'equal' is currently parsed
|
||||
original_token_count = 0
|
||||
edited_token_count = 0
|
||||
edit_options = []
|
||||
edit_opcodes = []
|
||||
# beginning of sequence
|
||||
edit_opcodes.append(
|
||||
('equal', original_token_count, original_token_count + 1, edited_token_count, edited_token_count + 1))
|
||||
edit_options.append(None)
|
||||
original_token_count += 1
|
||||
edited_token_count += 1
|
||||
for fragment in prompt.children:
|
||||
if type(fragment) is CrossAttentionControlSubstitute:
|
||||
original_prompt.append(fragment.original)
|
||||
edited_prompt.append(fragment.edited)
|
||||
|
||||
return SequenceMatcher(None, original_tokens, edited_tokens).get_opcodes()
|
||||
to_replace_token_count = _get_tokens_length(model, fragment.original)
|
||||
replacement_token_count = _get_tokens_length(model, fragment.edited)
|
||||
edit_opcodes.append(('replace',
|
||||
original_token_count, original_token_count + to_replace_token_count,
|
||||
edited_token_count, edited_token_count + replacement_token_count
|
||||
))
|
||||
original_token_count += to_replace_token_count
|
||||
edited_token_count += replacement_token_count
|
||||
edit_options.append(fragment.options)
|
||||
# elif type(fragment) is CrossAttentionControlAppend:
|
||||
# edited_prompt.append(fragment.fragment)
|
||||
else:
|
||||
# regular fragment
|
||||
original_prompt.append(fragment)
|
||||
edited_prompt.append(fragment)
|
||||
|
||||
def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt, log_tokens: bool=False, log_display_label: str=None):
|
||||
count = _get_tokens_length(model, [fragment])
|
||||
edit_opcodes.append(('equal', original_token_count, original_token_count + count, edited_token_count,
|
||||
edited_token_count + count))
|
||||
edit_options.append(None)
|
||||
original_token_count += count
|
||||
edited_token_count += count
|
||||
# end of sequence
|
||||
edit_opcodes.append(
|
||||
('equal', original_token_count, original_token_count + 1, edited_token_count, edited_token_count + 1))
|
||||
edit_options.append(None)
|
||||
original_token_count += 1
|
||||
edited_token_count += 1
|
||||
original_embeddings, original_tokens = _get_embeddings_and_tokens_for_prompt(model,
|
||||
original_prompt,
|
||||
log_tokens=log_tokens,
|
||||
log_display_label="(.swap originals)")
|
||||
# naïvely building a single edited_embeddings like this disregards the effects of changing the absolute location of
|
||||
# subsequent tokens when there is >1 edit and earlier edits change the total token count.
|
||||
# eg "a cat.swap(smiling dog, s_start=0.5) eating a hotdog.swap(pizza)" - when the 'pizza' edit is active but the
|
||||
# 'cat' edit is not, the 'pizza' feature vector will nevertheless be affected by the introduction of the extra
|
||||
# token 'smiling' in the inactive 'cat' edit.
|
||||
# todo: build multiple edited_embeddings, one for each edit, and pass just the edited fragments through to the CrossAttentionControl functions
|
||||
edited_embeddings, edited_tokens = _get_embeddings_and_tokens_for_prompt(model,
|
||||
edited_prompt,
|
||||
log_tokens=log_tokens,
|
||||
log_display_label="(.swap replacements)")
|
||||
conditioning = original_embeddings
|
||||
edited_conditioning = edited_embeddings
|
||||
# print('>> got edit_opcodes', edit_opcodes, 'options', edit_options)
|
||||
cac_args = cross_attention_control.Arguments(
|
||||
edited_conditioning=edited_conditioning,
|
||||
edit_opcodes=edit_opcodes,
|
||||
edit_options=edit_options
|
||||
)
|
||||
return conditioning, cac_args
|
||||
|
||||
|
||||
def _get_conditioning_for_blend(model, blend: Blend, log_tokens: bool = False):
|
||||
embeddings_to_blend = None
|
||||
for i, flattened_prompt in enumerate(blend.prompts):
|
||||
this_embedding, _ = _get_embeddings_and_tokens_for_prompt(model,
|
||||
flattened_prompt,
|
||||
log_tokens=log_tokens,
|
||||
log_display_label=f"(blend part {i + 1}, weight={blend.weights[i]})")
|
||||
embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat(
|
||||
(embeddings_to_blend, this_embedding))
|
||||
conditioning = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0),
|
||||
blend.weights,
|
||||
normalize=blend.normalize_weights)
|
||||
return conditioning
|
||||
|
||||
|
||||
def _get_embeddings_and_tokens_for_prompt(model, flattened_prompt: FlattenedPrompt, log_tokens: bool = False,
|
||||
log_display_label: str = None):
|
||||
if type(flattened_prompt) is not FlattenedPrompt:
|
||||
raise Exception(f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead")
|
||||
fragments = [x.text for x in flattened_prompt.children]
|
||||
@ -181,12 +228,14 @@ def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: Fl
|
||||
|
||||
return embeddings, tokens
|
||||
|
||||
def get_tokens_length(model, fragments: list[Fragment]):
|
||||
|
||||
def _get_tokens_length(model, fragments: list[Fragment]):
|
||||
fragment_texts = [x.text for x in fragments]
|
||||
tokens = model.cond_stage_model.get_tokens(fragment_texts, include_start_and_end_markers=False)
|
||||
return sum([len(x) for x in tokens])
|
||||
|
||||
def flatten_hybrid_conditioning(uncond, cond):
|
||||
|
||||
def _flatten_hybrid_conditioning(uncond, cond):
|
||||
'''
|
||||
This handles the choice between a conditional conditioning
|
||||
that is a tensor (used by cross attention) vs one that has additional
|
||||
@ -205,4 +254,29 @@ def flatten_hybrid_conditioning(uncond, cond):
|
||||
cond_flattened[k] = torch.cat([uncond[k], cond[k]])
|
||||
return uncond, cond_flattened
|
||||
|
||||
|
||||
|
||||
def log_tokenization(text, model, display_label=None):
|
||||
""" shows how the prompt is tokenized
|
||||
# usually tokens have '</w>' to indicate end-of-word,
|
||||
# but for readability it has been replaced with ' '
|
||||
"""
|
||||
|
||||
tokens = model.cond_stage_model.tokenizer.tokenize(text)
|
||||
tokenized = ""
|
||||
discarded = ""
|
||||
usedTokens = 0
|
||||
totalTokens = len(tokens)
|
||||
for i in range(0, totalTokens):
|
||||
token = tokens[i].replace('</w>', ' ')
|
||||
# alternate color
|
||||
s = (usedTokens % 6) + 1
|
||||
if i < model.cond_stage_model.max_length:
|
||||
tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
|
||||
usedTokens += 1
|
||||
else: # over max token length
|
||||
discarded = discarded + f"\x1b[0;3{s};40m{token}"
|
||||
print(f"\n>> Tokens {display_label or ''} ({usedTokens}):\n{tokenized}\x1b[0m")
|
||||
if discarded != "":
|
||||
print(
|
||||
f">> Tokens Discarded ({totalTokens - usedTokens}):\n{discarded}\x1b[0m"
|
||||
)
|
||||
|
@ -14,6 +14,7 @@ import cv2 as cv
|
||||
from einops import rearrange, repeat
|
||||
from pytorch_lightning import seed_everything
|
||||
from ldm.invoke.devices import choose_autocast
|
||||
from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
|
||||
from ldm.util import rand_perlin_2d
|
||||
|
||||
downsampling = 8
|
||||
@ -51,9 +52,12 @@ class Generator():
|
||||
def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None,
|
||||
image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
|
||||
safety_checker:dict=None,
|
||||
attention_maps_callback = None,
|
||||
**kwargs):
|
||||
scope = choose_autocast(self.precision)
|
||||
self.safety_checker = safety_checker
|
||||
attention_maps_images = []
|
||||
attention_maps_callback = lambda saver: attention_maps_images.append(saver.get_stacked_maps_image())
|
||||
make_image = self.get_make_image(
|
||||
prompt,
|
||||
sampler = sampler,
|
||||
@ -63,6 +67,7 @@ class Generator():
|
||||
step_callback = step_callback,
|
||||
threshold = threshold,
|
||||
perlin = perlin,
|
||||
attention_maps_callback = attention_maps_callback,
|
||||
**kwargs
|
||||
)
|
||||
results = []
|
||||
@ -98,12 +103,12 @@ class Generator():
|
||||
results.append([image, seed])
|
||||
|
||||
if image_callback is not None:
|
||||
image_callback(image, seed, first_seed=first_seed)
|
||||
image_callback(image, seed, first_seed=first_seed, attention_maps_image=attention_maps_images[-1])
|
||||
|
||||
seed = self.new_seed()
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def sample_to_image(self,samples)->Image.Image:
|
||||
"""
|
||||
Given samples returned from a sampler, converts
|
||||
@ -166,12 +171,12 @@ class Generator():
|
||||
blurred_init_mask = pil_init_mask
|
||||
|
||||
multiplied_blurred_init_mask = ImageChops.multiply(blurred_init_mask, self.pil_image.split()[-1])
|
||||
|
||||
|
||||
# Paste original on color-corrected generation (using blurred mask)
|
||||
matched_result.paste(init_image, (0,0), mask = multiplied_blurred_init_mask)
|
||||
return matched_result
|
||||
|
||||
|
||||
|
||||
|
||||
def sample_to_lowres_estimated_image(self,samples):
|
||||
# origingally adapted from code by @erucipe and @keturn here:
|
||||
@ -219,11 +224,11 @@ class Generator():
|
||||
(txt2img) or from the latent image (img2img, inpaint)
|
||||
"""
|
||||
raise NotImplementedError("get_noise() must be implemented in a descendent class")
|
||||
|
||||
|
||||
def get_perlin_noise(self,width,height):
|
||||
fixdevice = 'cpu' if (self.model.device.type == 'mps') else self.model.device
|
||||
return torch.stack([rand_perlin_2d((height, width), (8, 8), device = self.model.device).to(fixdevice) for _ in range(self.latent_channels)], dim=0).to(self.model.device)
|
||||
|
||||
|
||||
def new_seed(self):
|
||||
self.seed = random.randrange(0, np.iinfo(np.uint32).max)
|
||||
return self.seed
|
||||
@ -325,4 +330,4 @@ class Generator():
|
||||
os.makedirs(dirname, exist_ok=True)
|
||||
image.save(filepath,'PNG')
|
||||
|
||||
|
||||
|
||||
|
@ -14,7 +14,9 @@ class Txt2Img(Generator):
|
||||
|
||||
@torch.no_grad()
|
||||
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
|
||||
conditioning,width,height,step_callback=None,threshold=0.0,perlin=0.0,**kwargs):
|
||||
conditioning,width,height,step_callback=None,threshold=0.0,perlin=0.0,
|
||||
attention_maps_callback=None,
|
||||
**kwargs):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and the initial image
|
||||
Return value depends on the seed at the time you call it
|
||||
@ -33,7 +35,7 @@ class Txt2Img(Generator):
|
||||
|
||||
if self.free_gpu_mem and self.model.model.device != self.model.device:
|
||||
self.model.model.to(self.model.device)
|
||||
|
||||
|
||||
sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False)
|
||||
|
||||
samples, _ = sampler.sample(
|
||||
@ -49,6 +51,7 @@ class Txt2Img(Generator):
|
||||
eta = ddim_eta,
|
||||
img_callback = step_callback,
|
||||
threshold = threshold,
|
||||
attention_maps_callback = attention_maps_callback,
|
||||
)
|
||||
|
||||
if self.free_gpu_mem:
|
||||
|
@ -3,7 +3,7 @@ from typing import Union, Optional
|
||||
import re
|
||||
import pyparsing as pp
|
||||
'''
|
||||
This module parses prompt strings and produces tree-like structures that can be used generate and control the conditioning tensors.
|
||||
This module parses prompt strings and produces tree-like structures that can be used generate and control the conditioning tensors.
|
||||
weighted subprompts.
|
||||
|
||||
Useful class exports:
|
||||
@ -69,6 +69,12 @@ class FlattenedPrompt():
|
||||
return len(self.children) == 0 or \
|
||||
(len(self.children) == 1 and len(self.children[0].text) == 0)
|
||||
|
||||
@property
|
||||
def wants_cross_attention_control(self):
|
||||
return any(
|
||||
[issubclass(type(x), CrossAttentionControlledFragment) for x in self.children]
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"FlattenedPrompt:{self.children}"
|
||||
def __eq__(self, other):
|
||||
@ -240,6 +246,12 @@ class Blend():
|
||||
self.weights = weights
|
||||
self.normalize_weights = normalize_weights
|
||||
|
||||
@property
|
||||
def wants_cross_attention_control(self):
|
||||
# blends cannot cross-attention control
|
||||
return False
|
||||
|
||||
|
||||
def __repr__(self):
|
||||
return f"Blend:{self.prompts} | weights {' ' if self.normalize_weights else '(non-normalized) '}{self.weights}"
|
||||
def __eq__(self, other):
|
||||
@ -277,8 +289,8 @@ class PromptParser():
|
||||
|
||||
return self.flatten(root[0])
|
||||
|
||||
def parse_legacy_blend(self, text: str) -> Optional[Blend]:
|
||||
weighted_subprompts = split_weighted_subprompts(text, skip_normalize=False)
|
||||
def parse_legacy_blend(self, text: str, skip_normalize: bool) -> Optional[Blend]:
|
||||
weighted_subprompts = split_weighted_subprompts(text, skip_normalize=skip_normalize)
|
||||
if len(weighted_subprompts) <= 1:
|
||||
return None
|
||||
strings = [x[0] for x in weighted_subprompts]
|
||||
@ -287,7 +299,7 @@ class PromptParser():
|
||||
parsed_conjunctions = [self.parse_conjunction(x) for x in strings]
|
||||
flattened_prompts = [x.prompts[0] for x in parsed_conjunctions]
|
||||
|
||||
return Blend(prompts=flattened_prompts, weights=weights, normalize_weights=True)
|
||||
return Blend(prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize)
|
||||
|
||||
|
||||
def flatten(self, root: Conjunction, verbose = False) -> Conjunction:
|
||||
@ -641,27 +653,3 @@ def split_weighted_subprompts(text, skip_normalize=False)->list:
|
||||
return [(x[0], equal_weight) for x in parsed_prompts]
|
||||
return [(x[0], x[1] / weight_sum) for x in parsed_prompts]
|
||||
|
||||
|
||||
# shows how the prompt is tokenized
|
||||
# usually tokens have '</w>' to indicate end-of-word,
|
||||
# but for readability it has been replaced with ' '
|
||||
def log_tokenization(text, model, display_label=None):
|
||||
tokens = model.cond_stage_model.tokenizer.tokenize(text)
|
||||
tokenized = ""
|
||||
discarded = ""
|
||||
usedTokens = 0
|
||||
totalTokens = len(tokens)
|
||||
for i in range(0, totalTokens):
|
||||
token = tokens[i].replace('</w>', ' ')
|
||||
# alternate color
|
||||
s = (usedTokens % 6) + 1
|
||||
if i < model.cond_stage_model.max_length:
|
||||
tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
|
||||
usedTokens += 1
|
||||
else: # over max token length
|
||||
discarded = discarded + f"\x1b[0;3{s};40m{token}"
|
||||
print(f"\n>> Tokens {display_label or ''} ({usedTokens}):\n{tokenized}\x1b[0m")
|
||||
if discarded != "":
|
||||
print(
|
||||
f">> Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m"
|
||||
)
|
||||
|
@ -53,7 +53,6 @@ COMMANDS = (
|
||||
'--codeformer_fidelity','-cf',
|
||||
'--upscale','-U',
|
||||
'-save_orig','--save_original',
|
||||
'--skip_normalize','-x',
|
||||
'--log_tokenization','-t',
|
||||
'--hires_fix',
|
||||
'--inpaint_replace','-r',
|
||||
@ -117,19 +116,19 @@ class Completer(object):
|
||||
# extensions defined, so go directly into path completion mode
|
||||
if self.extensions is not None:
|
||||
self.matches = self._path_completions(text, state, self.extensions)
|
||||
|
||||
|
||||
# looking for an image file
|
||||
elif re.search(path_regexp,buffer):
|
||||
do_shortcut = re.search('^'+'|'.join(IMG_FILE_COMMANDS),buffer)
|
||||
self.matches = self._path_completions(text, state, IMG_EXTENSIONS,shortcut_ok=do_shortcut)
|
||||
|
||||
# looking for a seed
|
||||
elif re.search('(-S\s*|--seed[=\s])\d*$',buffer):
|
||||
elif re.search('(-S\s*|--seed[=\s])\d*$',buffer):
|
||||
self.matches= self._seed_completions(text,state)
|
||||
|
||||
elif re.search('<[\w-]*$',buffer):
|
||||
elif re.search('<[\w-]*$',buffer):
|
||||
self.matches= self._concept_completions(text,state)
|
||||
|
||||
|
||||
# looking for a model
|
||||
elif re.match('^'+'|'.join(MODEL_COMMANDS),buffer):
|
||||
self.matches= self._model_completions(text, state)
|
||||
@ -227,7 +226,7 @@ class Completer(object):
|
||||
if h_len < 1:
|
||||
print('<empty history>')
|
||||
return
|
||||
|
||||
|
||||
for i in range(0,h_len):
|
||||
line = self.get_history_item(i+1)
|
||||
if match and match not in line:
|
||||
@ -367,7 +366,7 @@ class DummyCompleter(Completer):
|
||||
def __init__(self,options):
|
||||
super().__init__(options)
|
||||
self.history = list()
|
||||
|
||||
|
||||
def add_history(self,line):
|
||||
self.history.append(line)
|
||||
|
||||
|
@ -1,12 +1,14 @@
|
||||
import enum
|
||||
from typing import Optional
|
||||
import math
|
||||
from typing import Optional, Callable
|
||||
|
||||
import psutil
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
# adapted from bloc97's CrossAttentionControl colab
|
||||
# https://github.com/bloc97/CrossAttentionControl
|
||||
|
||||
|
||||
class Arguments:
|
||||
def __init__(self, edited_conditioning: torch.Tensor, edit_opcodes: list[tuple], edit_options: dict):
|
||||
"""
|
||||
@ -63,9 +65,13 @@ class Context:
|
||||
self.clear_requests(cleanup=True)
|
||||
|
||||
def register_cross_attention_modules(self, model):
|
||||
for name,module in get_attention_modules(model, CrossAttentionType.SELF):
|
||||
for name,module in get_cross_attention_modules(model, CrossAttentionType.SELF):
|
||||
if name in self.self_cross_attention_module_identifiers:
|
||||
assert False, f"name {name} cannot appear more than once"
|
||||
self.self_cross_attention_module_identifiers.append(name)
|
||||
for name,module in get_attention_modules(model, CrossAttentionType.TOKENS):
|
||||
for name,module in get_cross_attention_modules(model, CrossAttentionType.TOKENS):
|
||||
if name in self.tokens_cross_attention_module_identifiers:
|
||||
assert False, f"name {name} cannot appear more than once"
|
||||
self.tokens_cross_attention_module_identifiers.append(name)
|
||||
|
||||
def request_save_attention_maps(self, cross_attention_type: CrossAttentionType):
|
||||
@ -166,6 +172,135 @@ class Context:
|
||||
map_dict[offset] = slice.to('cpu')
|
||||
|
||||
|
||||
|
||||
class InvokeAICrossAttentionMixin:
|
||||
"""
|
||||
Enable InvokeAI-flavoured CrossAttention calculation, which does aggressive low-memory slicing and calls
|
||||
through both to an attention_slice_wrangler and a slicing_strategy_getter for custom attention map wrangling
|
||||
and dymamic slicing strategy selection.
|
||||
"""
|
||||
def __init__(self):
|
||||
self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
|
||||
self.attention_slice_wrangler = None
|
||||
self.slicing_strategy_getter = None
|
||||
self.attention_slice_calculated_callback = None
|
||||
|
||||
def set_attention_slice_wrangler(self, wrangler: Optional[Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]]):
|
||||
'''
|
||||
Set custom attention calculator to be called when attention is calculated
|
||||
:param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size),
|
||||
which returns either the suggested_attention_slice or an adjusted equivalent.
|
||||
`module` is the current CrossAttention module for which the callback is being invoked.
|
||||
`suggested_attention_slice` is the default-calculated attention slice
|
||||
`dim` is -1 if the attenion map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
|
||||
If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length.
|
||||
|
||||
Pass None to use the default attention calculation.
|
||||
:return:
|
||||
'''
|
||||
self.attention_slice_wrangler = wrangler
|
||||
|
||||
def set_slicing_strategy_getter(self, getter: Optional[Callable[[nn.Module], tuple[int,int]]]):
|
||||
self.slicing_strategy_getter = getter
|
||||
|
||||
def set_attention_slice_calculated_callback(self, callback: Optional[Callable[[torch.Tensor], None]]):
|
||||
self.attention_slice_calculated_callback = callback
|
||||
|
||||
def einsum_lowest_level(self, query, key, value, dim, offset, slice_size):
|
||||
# calculate attention scores
|
||||
#attention_scores = torch.einsum('b i d, b j d -> b i j', q, k)
|
||||
attention_scores = torch.baddbmm(
|
||||
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
|
||||
query,
|
||||
key.transpose(-1, -2),
|
||||
beta=0,
|
||||
alpha=self.scale,
|
||||
)
|
||||
|
||||
# calculate attention slice by taking the best scores for each latent pixel
|
||||
default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
|
||||
attention_slice_wrangler = self.attention_slice_wrangler
|
||||
if attention_slice_wrangler is not None:
|
||||
attention_slice = attention_slice_wrangler(self, default_attention_slice, dim, offset, slice_size)
|
||||
else:
|
||||
attention_slice = default_attention_slice
|
||||
|
||||
if self.attention_slice_calculated_callback is not None:
|
||||
self.attention_slice_calculated_callback(attention_slice, dim, offset, slice_size)
|
||||
|
||||
hidden_states = torch.bmm(attention_slice, value)
|
||||
return hidden_states
|
||||
|
||||
def einsum_op_slice_dim0(self, q, k, v, slice_size):
|
||||
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
for i in range(0, q.shape[0], slice_size):
|
||||
end = i + slice_size
|
||||
r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size)
|
||||
return r
|
||||
|
||||
def einsum_op_slice_dim1(self, q, k, v, slice_size):
|
||||
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size)
|
||||
return r
|
||||
|
||||
def einsum_op_mps_v1(self, q, k, v):
|
||||
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
|
||||
return self.einsum_lowest_level(q, k, v, None, None, None)
|
||||
else:
|
||||
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
|
||||
return self.einsum_op_slice_dim1(q, k, v, slice_size)
|
||||
|
||||
def einsum_op_mps_v2(self, q, k, v):
|
||||
if self.mem_total_gb > 8 and q.shape[1] <= 4096:
|
||||
return self.einsum_lowest_level(q, k, v, None, None, None)
|
||||
else:
|
||||
return self.einsum_op_slice_dim0(q, k, v, 1)
|
||||
|
||||
def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb):
|
||||
size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
|
||||
if size_mb <= max_tensor_mb:
|
||||
return self.einsum_lowest_level(q, k, v, None, None, None)
|
||||
div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
|
||||
if div <= q.shape[0]:
|
||||
return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div)
|
||||
return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1))
|
||||
|
||||
def einsum_op_cuda(self, q, k, v):
|
||||
# check if we already have a slicing strategy (this should only happen during cross-attention controlled generation)
|
||||
slicing_strategy_getter = self.slicing_strategy_getter
|
||||
if slicing_strategy_getter is not None:
|
||||
(dim, slice_size) = slicing_strategy_getter(self)
|
||||
if dim is not None:
|
||||
# print("using saved slicing strategy with dim", dim, "slice size", slice_size)
|
||||
if dim == 0:
|
||||
return self.einsum_op_slice_dim0(q, k, v, slice_size)
|
||||
elif dim == 1:
|
||||
return self.einsum_op_slice_dim1(q, k, v, slice_size)
|
||||
|
||||
# fallback for when there is no saved strategy, or saved strategy does not slice
|
||||
mem_free_total = self.cached_mem_free_total or get_mem_free_total(q.device)
|
||||
# Divide factor of safety as there's copying and fragmentation
|
||||
return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
|
||||
|
||||
|
||||
def get_invokeai_attention_mem_efficient(self, q, k, v):
|
||||
if q.device.type == 'cuda':
|
||||
#print("in get_attention_mem_efficient with q shape", q.shape, ", k shape", k.shape, ", free memory is", get_mem_free_total(q.device))
|
||||
return self.einsum_op_cuda(q, k, v)
|
||||
|
||||
if q.device.type == 'mps' or q.device.type == 'cpu':
|
||||
if self.mem_total_gb >= 32:
|
||||
return self.einsum_op_mps_v1(q, k, v)
|
||||
return self.einsum_op_mps_v2(q, k, v)
|
||||
|
||||
# Smaller slices are faster due to L2/L3/SLC caches.
|
||||
# Tested on i7 with 8MB L3 cache.
|
||||
return self.einsum_op_tensor_mem(q, k, v, 32)
|
||||
|
||||
|
||||
|
||||
def remove_cross_attention_control(model):
|
||||
remove_attention_function(model)
|
||||
|
||||
@ -187,7 +322,7 @@ def setup_cross_attention_control(model, context: Context):
|
||||
# mask=1 means use base prompt attention, mask=0 means use edited prompt attention
|
||||
mask = torch.zeros(max_length)
|
||||
indices_target = torch.arange(max_length, dtype=torch.long)
|
||||
indices = torch.zeros(max_length, dtype=torch.long)
|
||||
indices = torch.arange(max_length, dtype=torch.long)
|
||||
for name, a0, a1, b0, b1 in context.arguments.edit_opcodes:
|
||||
if b0 < max_length:
|
||||
if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0):
|
||||
@ -201,10 +336,23 @@ def setup_cross_attention_control(model, context: Context):
|
||||
inject_attention_function(model, context)
|
||||
|
||||
|
||||
def get_attention_modules(model, which: CrossAttentionType):
|
||||
def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
|
||||
cross_attention_class: type = InvokeAICrossAttentionMixin
|
||||
# cross_attention_class: type = InvokeAIDiffusersCrossAttention
|
||||
which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2"
|
||||
return [(name,module) for name, module in model.named_modules() if
|
||||
type(module).__name__ == "CrossAttention" and which_attn in name]
|
||||
attention_module_tuples = [(name,module) for name, module in model.named_modules() if
|
||||
isinstance(module, cross_attention_class) and which_attn in name]
|
||||
cross_attention_modules_in_model_count = len(attention_module_tuples)
|
||||
expected_count = 16
|
||||
if cross_attention_modules_in_model_count != expected_count:
|
||||
# non-fatal error but .swap() won't work.
|
||||
print(f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model " +
|
||||
f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching failed " +
|
||||
f"or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, " +
|
||||
f"and/or update the {expected_count} above to an appropriate number, and/or find and inform someone who knows " +
|
||||
f"what it means. This error is non-fatal, but it is likely that .swap() and attention map display will not " +
|
||||
f"work properly until it is fixed.")
|
||||
return attention_module_tuples
|
||||
|
||||
|
||||
def inject_attention_function(unet, context: Context):
|
||||
@ -244,19 +392,52 @@ def inject_attention_function(unet, context: Context):
|
||||
|
||||
return attention_slice
|
||||
|
||||
for name, module in unet.named_modules():
|
||||
module_name = type(module).__name__
|
||||
if module_name == "CrossAttention":
|
||||
module.identifier = name
|
||||
cross_attention_modules = get_cross_attention_modules(unet, CrossAttentionType.TOKENS) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
|
||||
for identifier, module in cross_attention_modules:
|
||||
module.identifier = identifier
|
||||
try:
|
||||
module.set_attention_slice_wrangler(attention_slice_wrangler)
|
||||
module.set_slicing_strategy_getter(lambda module, module_identifier=name: \
|
||||
context.get_slicing_strategy(module_identifier))
|
||||
module.set_slicing_strategy_getter(
|
||||
lambda module: context.get_slicing_strategy(identifier)
|
||||
)
|
||||
except AttributeError as e:
|
||||
if is_attribute_error_about(e, 'set_attention_slice_wrangler'):
|
||||
print(f"TODO: implement set_attention_slice_wrangler for {type(module)}") # TODO
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
def remove_attention_function(unet):
|
||||
# clear wrangler callback
|
||||
for name, module in unet.named_modules():
|
||||
module_name = type(module).__name__
|
||||
if module_name == "CrossAttention":
|
||||
cross_attention_modules = get_cross_attention_modules(unet, CrossAttentionType.TOKENS) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
|
||||
for identifier, module in cross_attention_modules:
|
||||
try:
|
||||
# clear wrangler callback
|
||||
module.set_attention_slice_wrangler(None)
|
||||
module.set_slicing_strategy_getter(None)
|
||||
except AttributeError as e:
|
||||
if is_attribute_error_about(e, 'set_attention_slice_wrangler'):
|
||||
print(f"TODO: implement set_attention_slice_wrangler for {type(module)}")
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
def is_attribute_error_about(error: AttributeError, attribute: str):
|
||||
if hasattr(error, 'name'): # Python 3.10
|
||||
return error.name == attribute
|
||||
else: # Python 3.9
|
||||
return attribute in str(error)
|
||||
|
||||
|
||||
|
||||
def get_mem_free_total(device):
|
||||
#only on cuda
|
||||
if not torch.cuda.is_available():
|
||||
return None
|
||||
stats = torch.cuda.memory_stats(device)
|
||||
mem_active = stats['active_bytes.all.current']
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
mem_free_cuda, _ = torch.cuda.mem_get_info(device)
|
||||
mem_free_torch = mem_reserved - mem_active
|
||||
mem_free_total = mem_free_cuda + mem_free_torch
|
||||
return mem_free_total
|
||||
|
||||
|
95
ldm/models/diffusion/cross_attention_map_saving.py
Normal file
95
ldm/models/diffusion/cross_attention_map_saving.py
Normal file
@ -0,0 +1,95 @@
|
||||
import math
|
||||
|
||||
import PIL
|
||||
import torch
|
||||
from torchvision.transforms.functional import resize as tv_resize, InterpolationMode
|
||||
|
||||
from ldm.models.diffusion.cross_attention_control import get_cross_attention_modules, CrossAttentionType
|
||||
|
||||
|
||||
class AttentionMapSaver():
|
||||
|
||||
def __init__(self, token_ids: range, latents_shape: torch.Size):
|
||||
self.token_ids = token_ids
|
||||
self.latents_shape = latents_shape
|
||||
#self.collated_maps = #torch.zeros([len(token_ids), latents_shape[0], latents_shape[1]])
|
||||
self.collated_maps = {}
|
||||
|
||||
def clear_maps(self):
|
||||
self.collated_maps = {}
|
||||
|
||||
def add_attention_maps(self, maps: torch.Tensor, key: str):
|
||||
"""
|
||||
Accumulate the given attention maps and store by summing with existing maps at the passed-in key (if any).
|
||||
:param maps: Attention maps to store. Expected shape [A, (H*W), N] where A is attention heads count, H and W are the map size (fixed per-key) and N is the number of tokens (typically 77).
|
||||
:param key: Storage key. If a map already exists for this key it will be summed with the incoming data. In this case the maps sizes (H and W) should match.
|
||||
:return: None
|
||||
"""
|
||||
key_and_size = f'{key}_{maps.shape[1]}'
|
||||
|
||||
# extract desired tokens
|
||||
maps = maps[:, :, self.token_ids]
|
||||
|
||||
# merge attention heads to a single map per token
|
||||
maps = torch.sum(maps, 0)
|
||||
|
||||
# store
|
||||
if key_and_size not in self.collated_maps:
|
||||
self.collated_maps[key_and_size] = torch.zeros_like(maps, device='cpu')
|
||||
self.collated_maps[key_and_size] += maps.cpu()
|
||||
|
||||
def write_maps_to_disk(self, path: str):
|
||||
pil_image = self.get_stacked_maps_image()
|
||||
pil_image.save(path, 'PNG')
|
||||
|
||||
def get_stacked_maps_image(self) -> PIL.Image:
|
||||
"""
|
||||
Scale all collected attention maps to the same size, blend them together and return as an image.
|
||||
:return: An image containing a vertical stack of blended attention maps, one for each requested token.
|
||||
"""
|
||||
num_tokens = len(self.token_ids)
|
||||
if num_tokens == 0:
|
||||
return None
|
||||
|
||||
latents_height = self.latents_shape[0]
|
||||
latents_width = self.latents_shape[1]
|
||||
|
||||
merged = None
|
||||
|
||||
for key, maps in self.collated_maps.items():
|
||||
|
||||
# maps has shape [(H*W), N] for N tokens
|
||||
# but we want [N, H, W]
|
||||
this_scale_factor = math.sqrt(maps.shape[0] / (latents_width * latents_height))
|
||||
this_maps_height = int(float(latents_height) * this_scale_factor)
|
||||
this_maps_width = int(float(latents_width) * this_scale_factor)
|
||||
# and we need to do some dimension juggling
|
||||
maps = torch.reshape(torch.swapdims(maps, 0, 1), [num_tokens, this_maps_height, this_maps_width])
|
||||
|
||||
# scale to output size if necessary
|
||||
if this_scale_factor != 1:
|
||||
maps = tv_resize(maps, [latents_height, latents_width], InterpolationMode.BICUBIC)
|
||||
|
||||
# normalize
|
||||
maps_min = torch.min(maps)
|
||||
maps_range = torch.max(maps) - maps_min
|
||||
#print(f"map {key} size {[this_maps_width, this_maps_height]} range {[maps_min, maps_min + maps_range]}")
|
||||
maps_normalized = (maps - maps_min) / maps_range
|
||||
# expand to (-0.1, 1.1) and clamp
|
||||
maps_normalized_expanded = maps_normalized * 1.1 - 0.05
|
||||
maps_normalized_expanded_clamped = torch.clamp(maps_normalized_expanded, 0, 1)
|
||||
|
||||
# merge together, producing a vertical stack
|
||||
maps_stacked = torch.reshape(maps_normalized_expanded_clamped, [num_tokens * latents_height, latents_width])
|
||||
|
||||
if merged is None:
|
||||
merged = maps_stacked
|
||||
else:
|
||||
# screen blend
|
||||
merged = 1 - (1 - maps_stacked)*(1 - merged)
|
||||
|
||||
if merged is None:
|
||||
return None
|
||||
|
||||
merged_bytes = merged.mul(0xff).byte()
|
||||
return PIL.Image.fromarray(merged_bytes.numpy(), mode='L')
|
@ -4,6 +4,7 @@ import k_diffusion as K
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .cross_attention_map_saving import AttentionMapSaver
|
||||
from .sampler import Sampler
|
||||
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
|
||||
@ -36,6 +37,7 @@ class CFGDenoiser(nn.Module):
|
||||
self.invokeai_diffuser = InvokeAIDiffuserComponent(model,
|
||||
model_forward_callback=lambda x, sigma, cond: self.inner_model(x, sigma, cond=cond))
|
||||
|
||||
|
||||
def prepare_to_sample(self, t_enc, **kwargs):
|
||||
|
||||
extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
|
||||
@ -106,12 +108,12 @@ class KSampler(Sampler):
|
||||
else:
|
||||
print(f'>> Ksampler using karras noise schedule (steps < {self.karras_max})')
|
||||
self.sigmas = self.karras_sigmas
|
||||
|
||||
|
||||
# ALERT: We are completely overriding the sample() method in the base class, which
|
||||
# means that inpainting will not work. To get this to work we need to be able to
|
||||
# modify the inner loop of k_heun, k_lms, etc, as is done in an ugly way
|
||||
# in the lstein/k-diffusion branch.
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(
|
||||
self,
|
||||
@ -145,7 +147,7 @@ class KSampler(Sampler):
|
||||
@torch.no_grad()
|
||||
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
|
||||
return x0
|
||||
|
||||
|
||||
# Most of these arguments are ignored and are only present for compatibility with
|
||||
# other samples
|
||||
@torch.no_grad()
|
||||
@ -158,6 +160,7 @@ class KSampler(Sampler):
|
||||
callback=None,
|
||||
normals_sequence=None,
|
||||
img_callback=None,
|
||||
attention_maps_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.0,
|
||||
mask=None,
|
||||
@ -171,7 +174,7 @@ class KSampler(Sampler):
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
extra_conditioning_info=None,
|
||||
extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo=None,
|
||||
threshold = 0,
|
||||
perlin = 0,
|
||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
@ -204,6 +207,12 @@ class KSampler(Sampler):
|
||||
|
||||
model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10))
|
||||
model_wrap_cfg.prepare_to_sample(S, extra_conditioning_info=extra_conditioning_info)
|
||||
|
||||
attention_map_token_ids = range(1, extra_conditioning_info.tokens_count_including_eos_bos - 1)
|
||||
attention_maps_saver = None if attention_maps_callback is None else AttentionMapSaver(token_ids = attention_map_token_ids, latents_shape=x.shape[-2:])
|
||||
if attention_maps_callback is not None:
|
||||
model_wrap_cfg.invokeai_diffuser.setup_attention_map_saving(attention_maps_saver)
|
||||
|
||||
extra_args = {
|
||||
'cond': conditioning,
|
||||
'uncond': unconditional_conditioning,
|
||||
@ -217,6 +226,8 @@ class KSampler(Sampler):
|
||||
),
|
||||
None,
|
||||
)
|
||||
if attention_maps_callback is not None:
|
||||
attention_maps_callback(attention_maps_saver)
|
||||
return sampling_result
|
||||
|
||||
# this code will support inpainting if and when ksampler API modified or
|
||||
@ -248,7 +259,7 @@ class KSampler(Sampler):
|
||||
# terrible, confusing names here
|
||||
steps = self.ddim_num_steps
|
||||
t_enc = self.t_enc
|
||||
|
||||
|
||||
# sigmas is a full steps in length, but t_enc might
|
||||
# be less. We start in the middle of the sigma array
|
||||
# and work our way to the end after t_enc steps.
|
||||
@ -280,7 +291,7 @@ class KSampler(Sampler):
|
||||
return x_T + x
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
def prepare_to_sample(self,t_enc,**kwargs):
|
||||
self.t_enc = t_enc
|
||||
self.model_wrap = None
|
||||
|
@ -5,8 +5,8 @@ from typing import Callable, Optional, Union
|
||||
import torch
|
||||
|
||||
from ldm.models.diffusion.cross_attention_control import Arguments, \
|
||||
remove_cross_attention_control, setup_cross_attention_control, Context
|
||||
from ldm.modules.attention import get_mem_free_total
|
||||
remove_cross_attention_control, setup_cross_attention_control, Context, get_cross_attention_modules, CrossAttentionType
|
||||
from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
|
||||
|
||||
|
||||
class InvokeAIDiffuserComponent:
|
||||
@ -21,7 +21,8 @@ class InvokeAIDiffuserComponent:
|
||||
|
||||
|
||||
class ExtraConditioningInfo:
|
||||
def __init__(self, cross_attention_control_args: Optional[Arguments]):
|
||||
def __init__(self, tokens_count_including_eos_bos:int, cross_attention_control_args: Optional[Arguments]):
|
||||
self.tokens_count_including_eos_bos = tokens_count_including_eos_bos
|
||||
self.cross_attention_control_args = cross_attention_control_args
|
||||
|
||||
@property
|
||||
@ -52,7 +53,25 @@ class InvokeAIDiffuserComponent:
|
||||
self.cross_attention_control_context = None
|
||||
remove_cross_attention_control(self.model)
|
||||
|
||||
def setup_attention_map_saving(self, saver: AttentionMapSaver):
|
||||
def callback(slice, dim, offset, slice_size, key):
|
||||
if dim is not None:
|
||||
# sliced tokens attention map saving is not implemented
|
||||
return
|
||||
saver.add_attention_maps(slice, key)
|
||||
|
||||
tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
|
||||
for identifier, module in tokens_cross_attention_modules:
|
||||
key = ('down' if identifier.startswith('down') else
|
||||
'up' if identifier.startswith('up') else
|
||||
'mid')
|
||||
module.set_attention_slice_calculated_callback(
|
||||
lambda slice, dim, offset, slice_size, key=key: callback(slice, dim, offset, slice_size, key))
|
||||
|
||||
def remove_attention_map_saving(self):
|
||||
tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
|
||||
for _, module in tokens_cross_attention_modules:
|
||||
module.set_attention_slice_calculated_callback(None)
|
||||
|
||||
def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor,
|
||||
unconditioning: Union[torch.Tensor,dict],
|
||||
|
@ -7,10 +7,9 @@ import torch.nn.functional as F
|
||||
from torch import nn, einsum
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from ldm.models.diffusion.cross_attention_control import InvokeAICrossAttentionMixin
|
||||
from ldm.modules.diffusionmodules.util import checkpoint
|
||||
|
||||
import psutil
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
@ -164,9 +163,10 @@ def get_mem_free_total(device):
|
||||
return mem_free_total
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
class CrossAttention(nn.Module, InvokeAICrossAttentionMixin):
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
|
||||
super().__init__()
|
||||
InvokeAICrossAttentionMixin.__init__(self)
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = default(context_dim, query_dim)
|
||||
|
||||
@ -182,118 +182,6 @@ class CrossAttention(nn.Module):
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
|
||||
|
||||
self.cached_mem_free_total = None
|
||||
self.attention_slice_wrangler = None
|
||||
self.slicing_strategy_getter = None
|
||||
|
||||
def set_attention_slice_wrangler(self, wrangler: Optional[Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]]):
|
||||
'''
|
||||
Set custom attention calculator to be called when attention is calculated
|
||||
:param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size),
|
||||
which returns either the suggested_attention_slice or an adjusted equivalent.
|
||||
`module` is the current CrossAttention module for which the callback is being invoked.
|
||||
`suggested_attention_slice` is the default-calculated attention slice
|
||||
`dim` is -1 if the attenion map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
|
||||
If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length.
|
||||
|
||||
Pass None to use the default attention calculation.
|
||||
:return:
|
||||
'''
|
||||
self.attention_slice_wrangler = wrangler
|
||||
|
||||
def set_slicing_strategy_getter(self, getter: Optional[Callable[[nn.Module], tuple[int,int]]]):
|
||||
self.slicing_strategy_getter = getter
|
||||
|
||||
def cache_free_memory_count(self, device):
|
||||
self.cached_mem_free_total = get_mem_free_total(device)
|
||||
print("free cuda memory: ", self.cached_mem_free_total)
|
||||
|
||||
def clear_cached_free_memory_count(self):
|
||||
self.cached_mem_free_total = None
|
||||
|
||||
def einsum_lowest_level(self, q, k, v, dim, offset, slice_size):
|
||||
# calculate attention scores
|
||||
attention_scores = einsum('b i d, b j d -> b i j', q, k)
|
||||
# calculate attention slice by taking the best scores for each latent pixel
|
||||
default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
|
||||
attention_slice_wrangler = self.attention_slice_wrangler
|
||||
if attention_slice_wrangler is not None:
|
||||
attention_slice = attention_slice_wrangler(self, default_attention_slice, dim, offset, slice_size)
|
||||
else:
|
||||
attention_slice = default_attention_slice
|
||||
|
||||
return einsum('b i j, b j d -> b i d', attention_slice, v)
|
||||
|
||||
def einsum_op_slice_dim0(self, q, k, v, slice_size):
|
||||
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
for i in range(0, q.shape[0], slice_size):
|
||||
end = i + slice_size
|
||||
r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size)
|
||||
return r
|
||||
|
||||
def einsum_op_slice_dim1(self, q, k, v, slice_size):
|
||||
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size)
|
||||
return r
|
||||
|
||||
def einsum_op_mps_v1(self, q, k, v):
|
||||
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
|
||||
return self.einsum_lowest_level(q, k, v, None, None, None)
|
||||
else:
|
||||
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
|
||||
return self.einsum_op_slice_dim1(q, k, v, slice_size)
|
||||
|
||||
def einsum_op_mps_v2(self, q, k, v):
|
||||
if self.mem_total_gb > 8 and q.shape[1] <= 4096:
|
||||
return self.einsum_lowest_level(q, k, v, None, None, None)
|
||||
else:
|
||||
return self.einsum_op_slice_dim0(q, k, v, 1)
|
||||
|
||||
def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb):
|
||||
size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
|
||||
if size_mb <= max_tensor_mb:
|
||||
return self.einsum_lowest_level(q, k, v, None, None, None)
|
||||
div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
|
||||
if div <= q.shape[0]:
|
||||
return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div)
|
||||
return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1))
|
||||
|
||||
def einsum_op_cuda(self, q, k, v):
|
||||
# check if we already have a slicing strategy (this should only happen during cross-attention controlled generation)
|
||||
slicing_strategy_getter = self.slicing_strategy_getter
|
||||
if slicing_strategy_getter is not None:
|
||||
(dim, slice_size) = slicing_strategy_getter(self)
|
||||
if dim is not None:
|
||||
# print("using saved slicing strategy with dim", dim, "slice size", slice_size)
|
||||
if dim == 0:
|
||||
return self.einsum_op_slice_dim0(q, k, v, slice_size)
|
||||
elif dim == 1:
|
||||
return self.einsum_op_slice_dim1(q, k, v, slice_size)
|
||||
|
||||
# fallback for when there is no saved strategy, or saved strategy does not slice
|
||||
mem_free_total = self.cached_mem_free_total or get_mem_free_total(q.device)
|
||||
# Divide factor of safety as there's copying and fragmentation
|
||||
return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
|
||||
|
||||
|
||||
def get_attention_mem_efficient(self, q, k, v):
|
||||
if q.device.type == 'cuda':
|
||||
#print("in get_attention_mem_efficient with q shape", q.shape, ", k shape", k.shape, ", free memory is", get_mem_free_total(q.device))
|
||||
return self.einsum_op_cuda(q, k, v)
|
||||
|
||||
if q.device.type == 'mps':
|
||||
if self.mem_total_gb >= 32:
|
||||
return self.einsum_op_mps_v1(q, k, v)
|
||||
return self.einsum_op_mps_v2(q, k, v)
|
||||
|
||||
# Smaller slices are faster due to L2/L3/SLC caches.
|
||||
# Tested on i7 with 8MB L3 cache.
|
||||
return self.einsum_op_tensor_mem(q, k, v, 32)
|
||||
|
||||
def forward(self, x, context=None, mask=None):
|
||||
h = self.heads
|
||||
|
||||
@ -305,7 +193,11 @@ class CrossAttention(nn.Module):
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||
|
||||
r = self.get_attention_mem_efficient(q, k, v)
|
||||
# don't apply scale twice
|
||||
cached_scale = self.scale
|
||||
self.scale = 1
|
||||
r = self.get_invokeai_attention_mem_efficient(q, k, v)
|
||||
self.scale = cached_scale
|
||||
|
||||
hidden_states = rearrange(r, '(b h) n d -> b n (h d)', h=h)
|
||||
return self.to_out(hidden_states)
|
||||
|
Loading…
Reference in New Issue
Block a user