Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
Save and display per-token attention maps (#1866)
* attention maps saving to /tmp
* tidy up diffusers branch backporting of cross attention refactoring
* base64-encoding the attention maps image for generationResult
* cleanup/refactor conditioning.py
* attention maps and tokens being sent to web UI
* attention maps: restrict count to actual token count and improve robustness
* add argument type hint to image_to_dataURL function

Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Co-authored-by: damian <git@damianstewart.com>
Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
This commit is contained in: parent 55132f6463, commit 786b8878d6
@@ -18,9 +18,11 @@ from PIL.Image import Image as ImageType
 from uuid import uuid4
 from threading import Event
 
+from ldm.generate import Generate
 from ldm.invoke.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash
+from ldm.invoke.conditioning import get_tokens_for_prompt, get_prompt_structure
 from ldm.invoke.pngwriter import PngWriter, retrieve_metadata
-from ldm.invoke.prompt_parser import split_weighted_subprompts
+from ldm.invoke.prompt_parser import split_weighted_subprompts, Blend
 from ldm.invoke.generator.inpaint import infill_methods
 
 from backend.modules.parameters import parameters_to_command
@@ -39,7 +41,7 @@ if not os.path.isabs(args.outdir):
 
 
 class InvokeAIWebServer:
-def __init__(self, generate, gfpgan, codeformer, esrgan) -> None:
+def __init__(self, generate: Generate, gfpgan, codeformer, esrgan) -> None:
 self.host = args.host
 self.port = args.port
 
@@ -905,16 +907,13 @@ class InvokeAIWebServer:
 },
 )
 
 
 if generation_parameters["progress_latents"]:
 image = self.generate.sample_to_lowres_estimated_image(sample)
 (width, height) = image.size
 width *= 8
 height *= 8
-buffered = io.BytesIO()
-image.save(buffered, format="PNG")
-img_base64 = "data:image/png;base64," + base64.b64encode(
-buffered.getvalue()
-).decode("UTF-8")
+img_base64 = image_to_dataURL(image)
 self.socketio.emit(
 "intermediateResult",
 {
@@ -932,7 +931,7 @@ class InvokeAIWebServer:
 self.socketio.emit("progressUpdate", progress.to_formatted_dict())
 eventlet.sleep(0)
 
-def image_done(image, seed, first_seed):
+def image_done(image, seed, first_seed, attention_maps_image=None):
 if self.canceled.is_set():
 raise CanceledException
 
@@ -1094,6 +1093,12 @@ class InvokeAIWebServer:
 self.socketio.emit("progressUpdate", progress.to_formatted_dict())
 eventlet.sleep(0)
 
+parsed_prompt, _ = get_prompt_structure(generation_parameters["prompt"])
+tokens = None if type(parsed_prompt) is Blend else \
+get_tokens_for_prompt(self.generate.model, parsed_prompt)
+attention_maps_image_base64_url = None if attention_maps_image is None \
+else image_to_dataURL(attention_maps_image)
 
 self.socketio.emit(
 "generationResult",
 {
@@ -1106,6 +1111,8 @@ class InvokeAIWebServer:
 "height": height,
 "boundingBox": original_bounding_box,
 "generationMode": generation_parameters["generation_mode"],
+"attentionMaps": attention_maps_image_base64_url,
+"tokens": tokens,
 },
 )
 eventlet.sleep(0)
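The generationResult event now carries two extra fields. The sketch below is illustrative only: the payload is abbreviated, the example values are invented, and decode_data_url is a hypothetical helper showing how a consumer could turn the base64 data URL back into a PIL image.

import base64
import io

from PIL import Image

def decode_data_url(data_url: str) -> Image.Image:
    # assumes the "data:image/png;base64," prefix produced by image_to_dataURL()
    _header, encoded = data_url.split(",", 1)
    return Image.open(io.BytesIO(base64.b64decode(encoded)))

# abbreviated example of the emitted payload
payload = {
    "generationMode": "txt2img",
    "attentionMaps": None,  # a real payload holds a "data:image/png;base64,..." string
    "tokens": ["a</w>", "photo</w>", "of</w>", "a</w>", "cat</w>"],  # None for blends
}

if payload["attentionMaps"] is not None:
    maps_image = decode_data_url(payload["attentionMaps"])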
@@ -1117,7 +1124,7 @@ class InvokeAIWebServer:
 self.generate.prompt2image(
 **generation_parameters,
 step_callback=image_progress,
-image_callback=image_done,
+image_callback=image_done
 )
 
 except KeyboardInterrupt:
@@ -1564,6 +1571,19 @@ def dataURL_to_image(dataURL: str) -> ImageType:
 )
 return image
 
+"""
+Converts an image into a base64 image dataURL.
+"""
+
+def image_to_dataURL(image: ImageType) -> str:
+buffered = io.BytesIO()
+image.save(buffered, format="PNG")
+image_base64 = "data:image/png;base64," + base64.b64encode(
+buffered.getvalue()
+).decode("UTF-8")
+return image_base64
+
+
 
 """
 Converts a base64 image dataURL into bytes.
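A minimal round-trip sketch of the new helper, assuming it runs in the same module so that image_to_dataURL() and the existing dataURL_to_image() counterpart are both in scope:

from PIL import Image

# encode a PIL image into a PNG data URL, then decode it back
original = Image.new("RGB", (64, 64), color="black")
data_url = image_to_dataURL(original)
assert data_url.startswith("data:image/png;base64,")

restored = dataURL_to_image(data_url)
assert restored.size == original.size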
@@ -20,6 +20,8 @@ import cv2
 import skimage
 
 from omegaconf import OmegaConf
+
+import ldm.invoke.conditioning
 from ldm.invoke.generator.base import downsampling
 from PIL import Image, ImageOps
 from torch import nn
@@ -40,7 +42,7 @@ from ldm.invoke.model_cache import ModelCache
 from ldm.invoke.seamless import configure_model_padding
 from ldm.invoke.txt2mask import Txt2Mask, SegmentedGrayscale
 from ldm.invoke.concepts_lib import Concepts
 
 def fix_func(orig):
 if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
 def new_func(*args, **kw):
@@ -235,7 +237,7 @@ class Generate:
 except Exception:
 print('** An error was encountered while installing the safety checker:')
 print(traceback.format_exc())
 
 def prompt2png(self, prompt, outdir, **kwargs):
 """
 Takes a prompt and an output directory, writes out the requested number
@@ -329,7 +331,7 @@ class Generate:
 infill_method = infill_methods[0], # The infill method to use
 force_outpaint: bool = False,
 enable_image_debugging = False,
 
 **args,
 ): # eat up additional cruft
 """
@@ -372,7 +374,7 @@ class Generate:
 def process_image(image,seed):
 image.save(f{'images/seed.png'})
 
 The code used to save images to a directory can be found in ldm/invoke/pngwriter.py.
 It contains code to create the requested output directory, select a unique informative
 name for each image, and write the prompt into the PNG metadata.
 """
@@ -455,7 +457,7 @@ class Generate:
 try:
 uc, c, extra_conditioning_info = get_uc_and_c_and_ec(
 prompt, model =self.model,
-skip_normalize=skip_normalize,
+skip_normalize_legacy_blend=skip_normalize,
 log_tokens =self.log_tokenization
 )
 
@@ -589,7 +591,7 @@ class Generate:
 seed = opt.seed or args.seed
 if seed is None or seed < 0:
 seed = random.randrange(0, np.iinfo(np.uint32).max)
 
 prompt = opt.prompt or args.prompt or ''
 print(f'>> using seed {seed} and prompt "{prompt}" for {image_path}')
 
@@ -607,8 +609,8 @@ class Generate:
 # todo: cross-attention control
 uc, c, extra_conditioning_info = get_uc_and_c_and_ec(
 prompt, model =self.model,
-skip_normalize=opt.skip_normalize,
-log_tokens =opt.log_tokenization
+skip_normalize_legacy_blend=opt.skip_normalize,
+log_tokens =ldm.invoke.conditioning.log_tokenization
 )
 
 if tool in ('gfpgan','codeformer','upscale'):
@@ -641,7 +643,7 @@ class Generate:
 
 opt.seed = seed
 opt.prompt = prompt
 
 if len(extend_instructions) > 0:
 restorer = Outcrop(image,self,)
 return restorer.process (
@@ -683,7 +685,7 @@ class Generate:
 image_callback = callback,
 prefix = prefix
 )
 
 elif tool is None:
 print(f'* please provide at least one postprocessing option, such as -G or -U')
 return None
@@ -706,13 +708,13 @@ class Generate:
 
 if embiggen is not None:
 return self._make_embiggen()
 
 if inpainting_model_in_use:
 return self._make_omnibus()
 
 if ((init_image is not None) and (mask_image is not None)) or force_outpaint:
 return self._make_inpaint()
 
 if init_image is not None:
 return self._make_img2img()
 
@@ -743,7 +745,7 @@ class Generate:
 if self._has_transparency(image):
 self._transparency_check_and_warning(image, mask, force_outpaint)
 init_mask = self._create_init_mask(image, width, height, fit=fit)
 
 if (image.width * image.height) > (self.width * self.height) and self.size_matters:
 print(">> This input is larger than your defaults. If you run out of memory, please use a smaller image.")
 self.size_matters = False
@@ -759,7 +761,7 @@ class Generate:
 
 if init_mask and invert_mask:
 init_mask = ImageOps.invert(init_mask)
 
 return init_image,init_mask
 
 # lots o' repeated code here! Turn into a make_func()
@@ -818,7 +820,7 @@ class Generate:
 self.set_model(self.model_name)
 
 def set_model(self,model_name):
 """
 Given the name of a model defined in models.yaml, will load and initialize it
 and return the model object. Previously-used models will be cached.
 """
@@ -830,7 +832,7 @@ class Generate:
 if not cache.valid_model(model_name):
 print(f'** "{model_name}" is not a known model name. Please check your models.yaml file')
 return self.model
 
 cache.print_vram_usage()
 
 # have to get rid of all references to model in order
@@ -839,7 +841,7 @@ class Generate:
 self.sampler = None
 self.generators = {}
 gc.collect()
 
 model_data = cache.get_model(model_name)
 if model_data is None: # restore previous
 model_data = cache.get_model(self.model_name)
@@ -852,7 +854,7 @@ class Generate:
 
 # uncache generators so they pick up new models
 self.generators = {}
 
 seed_everything(random.randrange(0, np.iinfo(np.uint32).max))
 if self.embedding_path is not None:
 self.model.embedding_manager.load(
@@ -901,7 +903,7 @@ class Generate:
 image_callback = None,
 prefix = None,
 ):
 
 for r in image_list:
 image, seed = r
 try:
@@ -911,7 +913,7 @@ class Generate:
 if self.gfpgan is None:
 print('>> GFPGAN not found. Face restoration is disabled.')
 else:
 image = self.gfpgan.process(image, strength, seed)
 if facetool == 'codeformer':
 if self.codeformer is None:
 print('>> CodeFormer not found. Face restoration is disabled.')
@@ -8,6 +8,7 @@ import time
 import traceback
 import yaml
 
+from ldm.generate import Generate
 from ldm.invoke.globals import Globals
 from ldm.invoke.prompt_parser import PromptParser
 from ldm.invoke.readline import get_completer, Completer
@@ -27,7 +28,7 @@ def main():
 """Initialize command-line parsers and the diffusion model"""
 global infile
 print('* Initializing, be patient...')
 
 opt = Args()
 args = opt.parse_args()
 if not args:
@@ -47,7 +48,7 @@ def main():
 # alert - setting globals here
 Globals.root = os.path.expanduser(args.root_dir or os.environ.get('INVOKEAI_ROOT') or os.path.abspath('.'))
 Globals.try_patchmatch = args.patchmatch
 
 print(f'>> InvokeAI runtime directory is "{Globals.root}"')
 
 # loading here to avoid long delays on startup
@@ -281,7 +282,7 @@ def main_loop(gen, opt):
 prefix = file_writer.unique_prefix()
 step_callback = make_step_callback(gen, opt, prefix) if opt.save_intermediates > 0 else None
 
-def image_writer(image, seed, upscaled=False, first_seed=None, use_prefix=None, prompt_in=None):
+def image_writer(image, seed, upscaled=False, first_seed=None, use_prefix=None, prompt_in=None, attention_maps_image=None):
 # note the seed is the seed of the current image
 # the first_seed is the original seed that noise is added to
 # when the -v switch is used to generate variations
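Generators now pass attention_maps_image through the image callback, so any callback written against the old signature needs the extra keyword parameter. A hedged sketch of such a callback that also writes the maps to disk (the function name and output paths are hypothetical):

from PIL import Image

def my_image_callback(image: Image.Image, seed: int, first_seed: int = None,
                      attention_maps_image: Image.Image = None):
    # save the generated image as before
    image.save(f"outputs/{seed}.png")
    # the attention maps arrive as a single stacked PIL image, or None
    if attention_maps_image is not None:
        attention_maps_image.save(f"outputs/{seed}.attention.png")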
@@ -341,8 +342,8 @@ def main_loop(gen, opt):
 filename,
 tool,
 formatted_dream_prompt,
 )
 
 if (not postprocessed) or opt.save_original:
 # only append to results if we didn't overwrite an earlier output
 results.append([path, formatted_dream_prompt])
@@ -432,7 +433,7 @@ def do_command(command:str, gen, opt:Args, completer) -> tuple:
 add_embedding_terms(gen, completer)
 completer.add_history(command)
 operation = None
 
 elif command.startswith('!models'):
 gen.model_cache.print_models()
 completer.add_history(command)
@@ -533,7 +534,7 @@ def add_weights_to_config(model_path:str, gen, opt, completer):
 
 completer.complete_extensions(('.yaml','.yml'))
 completer.linebuffer = 'configs/stable-diffusion/v1-inference.yaml'
 
 done = False
 while not done:
 new_config['config'] = input('Configuration file for this model: ')
@@ -564,7 +565,7 @@ def add_weights_to_config(model_path:str, gen, opt, completer):
 print('** Please enter a valid integer between 64 and 2048')
 
 make_default = input('Make this the default model? [n] ') in ('y','Y')
 
 if write_config_file(opt.conf, gen, model_name, new_config, make_default=make_default):
 completer.add_model(model_name)
 
@@ -577,14 +578,14 @@ def del_config(model_name:str, gen, opt, completer):
 gen.model_cache.commit(opt.conf)
 print(f'** {model_name} deleted')
 completer.del_model(model_name)
 
 def edit_config(model_name:str, gen, opt, completer):
 config = gen.model_cache.config
 
 if model_name not in config:
 print(f'** Unknown model {model_name}')
 return
 
 print(f'\n>> Editing model {model_name} from configuration file {opt.conf}')
 
 conf = config[model_name]
@@ -597,10 +598,10 @@ def edit_config(model_name:str, gen, opt, completer):
 make_default = input('Make this the default model? [n] ') in ('y','Y')
 completer.complete_extensions(None)
 write_config_file(opt.conf, gen, model_name, new_config, clobber=True, make_default=make_default)
 
 def write_config_file(conf_path, gen, model_name, new_config, clobber=False, make_default=False):
 current_model = gen.model_name
 
 op = 'modify' if clobber else 'import'
 print('\n>> New configuration:')
 if make_default:
@@ -623,7 +624,7 @@ def write_config_file(conf_path, gen, model_name, new_config, clobber=False, mak
 gen.model_cache.set_default_model(model_name)
 
 gen.model_cache.commit(conf_path)
 
 do_switch = input(f'Keep model loaded? [y]')
 if len(do_switch)==0 or do_switch[0] in ('y','Y'):
 pass
@@ -653,7 +654,7 @@ def do_postprocess (gen, opt, callback):
 opt.prompt = opt.new_prompt
 else:
 opt.prompt = None
 
 if os.path.dirname(file_path) == '': #basename given
 file_path = os.path.join(opt.outdir,file_path)
 
@@ -718,7 +719,7 @@ def add_postprocessing_to_metadata(opt,original_file,new_file,tool,command):
 )
 meta['image']['postprocessing'] = pp
 write_metadata(new_file,meta)
 
 def prepare_image_metadata(
 opt,
 prefix,
@@ -789,28 +790,28 @@ def get_next_command(infile=None) -> str: # command string
 print(f'#{command}')
 return command
 
-def invoke_ai_web_server_loop(gen, gfpgan, codeformer, esrgan):
+def invoke_ai_web_server_loop(gen: Generate, gfpgan, codeformer, esrgan):
 print('\n* --web was specified, starting web server...')
 from backend.invoke_ai_web_server import InvokeAIWebServer
 # Change working directory to the stable-diffusion directory
 os.chdir(
 os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
 )
 
 invoke_ai_web_server = InvokeAIWebServer(generate=gen, gfpgan=gfpgan, codeformer=codeformer, esrgan=esrgan)
 
 try:
 invoke_ai_web_server.run()
 except KeyboardInterrupt:
 pass
 
 def add_embedding_terms(gen,completer):
 '''
 Called after setting the model, updates the autocompleter with
 any terms loaded by the embedding manager.
 '''
 completer.add_embedding_terms(gen.model.embedding_manager.list_terms())
 
 def split_variations(variations_string) -> list:
 # shotgun parsing, woo
 parts = []
@@ -867,7 +868,7 @@ def make_step_callback(gen, opt, prefix):
 image = gen.sample_to_image(img)
 image.save(filename,'PNG')
 return callback
 
 def retrieve_dream_command(opt,command,completer):
 '''
 Given a full or partial path to a previously-generated image file,
@@ -875,7 +876,7 @@ def retrieve_dream_command(opt,command,completer):
 and pop it into the readline buffer (linux, Mac), or print out a comment
 for cut-and-paste (windows)
 
 Given a wildcard path to a folder with image png files,
 will retrieve and format the dream command used to generate the images,
 and save them to a file commands.txt for further processing
 '''
@@ -911,7 +912,7 @@ def write_commands(opt, file_path:str, outfilepath:str):
 except ValueError:
 print(f'## "{basename}": unacceptable pattern')
 return
 
 commands = []
 cmd = None
 for path in paths:
@@ -940,7 +941,7 @@ def emergency_model_reconfigure():
 print(' After reconfiguration is done, please relaunch invoke.py. ')
 print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
 print('configure_invokeai is launching....\n')
 
 sys.argv = ['configure_invokeai','--interactive']
 import configure_invokeai
 configure_invokeai.main()
@@ -7,20 +7,46 @@ get_uc_and_c_and_ec() get the conditioned and unconditioned latent, an
 
 '''
 import re
-from difflib import SequenceMatcher
 from typing import Union
 
 import torch
 
 from .prompt_parser import PromptParser, Blend, FlattenedPrompt, \
-CrossAttentionControlledFragment, CrossAttentionControlSubstitute, Fragment, log_tokenization
+CrossAttentionControlledFragment, CrossAttentionControlSubstitute, Fragment
 from ..models.diffusion import cross_attention_control
 from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
 from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder
 
 
-def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_normalize=False):
+def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False):
+prompt, negative_prompt = get_prompt_structure(prompt_string,
+skip_normalize_legacy_blend=skip_normalize_legacy_blend)
+conditioning = _get_conditioning_for_prompt(prompt, negative_prompt, model, log_tokens)
+
+return conditioning
+
+
+def get_prompt_structure(prompt_string, skip_normalize_legacy_blend: bool = False) -> (
+Union[FlattenedPrompt, Blend], FlattenedPrompt):
+"""
+parse the passed-in prompt string and return tuple (positive_prompt, negative_prompt)
+"""
+prompt, negative_prompt = _parse_prompt_string(prompt_string,
+skip_normalize_legacy_blend=skip_normalize_legacy_blend)
+return prompt, negative_prompt
+
+
+def get_tokens_for_prompt(model, parsed_prompt: FlattenedPrompt) -> [str]:
+text_fragments = [x.text if type(x) is Fragment else
+(" ".join([f.text for f in x.original]) if type(x) is CrossAttentionControlSubstitute else
+str(x))
+for x in parsed_prompt.children]
+text = " ".join(text_fragments)
+tokens = model.cond_stage_model.tokenizer.tokenize(text)
+return tokens
+
+
+def _parse_prompt_string(prompt_string_uncleaned, skip_normalize_legacy_blend=False) -> Union[FlattenedPrompt, Blend]:
 # Extract Unconditioned Words From Prompt
 unconditioned_words = ''
 unconditional_regex = r'\[(.*?)\]'
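A hedged sketch of how the two new public helpers fit together (tokens_for_ui is an illustrative name, not part of the commit; the model argument is assumed to be an already-loaded ldm model such as Generate().model):

from ldm.invoke.conditioning import get_prompt_structure, get_tokens_for_prompt
from ldm.invoke.prompt_parser import Blend

def tokens_for_ui(model, prompt_string: str):
    prompt, negative_prompt = get_prompt_structure(prompt_string)
    # Blends have no single token sequence, so callers skip tokenization for them
    if type(prompt) is Blend:
        return None
    return get_tokens_for_prompt(model, prompt)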
@@ -39,7 +65,7 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
 pp = PromptParser()
 
 parsed_prompt: Union[FlattenedPrompt, Blend] = None
-legacy_blend: Blend = pp.parse_legacy_blend(prompt_string_cleaned)
+legacy_blend: Blend = pp.parse_legacy_blend(prompt_string_cleaned, skip_normalize_legacy_blend)
 if legacy_blend is not None:
 parsed_prompt = legacy_blend
 else:
@@ -47,129 +73,150 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
 parsed_prompt = pp.parse_conjunction(prompt_string_cleaned).prompts[0]
 
 parsed_negative_prompt: FlattenedPrompt = pp.parse_conjunction(unconditioned_words).prompts[0]
+return parsed_prompt, parsed_negative_prompt
+
+
+def _get_conditioning_for_prompt(parsed_prompt: Union[Blend, FlattenedPrompt], parsed_negative_prompt: FlattenedPrompt,
+model, log_tokens=False) \
+-> tuple[torch.Tensor, torch.Tensor, InvokeAIDiffuserComponent.ExtraConditioningInfo]:
+"""
+Process prompt structure and tokens, and return (conditioning, unconditioning, extra_conditioning_info)
+"""
 
 if log_tokens:
 print(f">> Parsed prompt to {parsed_prompt}")
 print(f">> Parsed negative prompt to {parsed_negative_prompt}")
 
 conditioning = None
-cac_args:cross_attention_control.Arguments = None
+cac_args: cross_attention_control.Arguments = None
 
 if type(parsed_prompt) is Blend:
-blend: Blend = parsed_prompt
-embeddings_to_blend = None
-for i,flattened_prompt in enumerate(blend.prompts):
-this_embedding, _ = build_embeddings_and_tokens_for_flattened_prompt(model,
+conditioning = _get_conditioning_for_blend(model, parsed_prompt, log_tokens)
+elif type(parsed_prompt) is FlattenedPrompt:
+if parsed_prompt.wants_cross_attention_control:
+conditioning, cac_args = _get_conditioning_for_cross_attention_control(model, parsed_prompt, log_tokens)
-flattened_prompt,
-log_tokens=log_tokens,
-log_display_label=f"(blend part {i+1}, weight={blend.weights[i]})" )
-embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat(
-(embeddings_to_blend, this_embedding))
-conditioning = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0),
-blend.weights,
-normalize=blend.normalize_weights)
-else:
-flattened_prompt: FlattenedPrompt = parsed_prompt
-wants_cross_attention_control = type(flattened_prompt) is not Blend \
-and any([issubclass(type(x), CrossAttentionControlledFragment) for x in flattened_prompt.children])
-if wants_cross_attention_control:
-original_prompt = FlattenedPrompt()
-edited_prompt = FlattenedPrompt()
-# for name, a0, a1, b0, b1 in edit_opcodes: only name == 'equal' is currently parsed
-original_token_count = 0
-edited_token_count = 0
-edit_options = []
-edit_opcodes = []
-# beginning of sequence
-edit_opcodes.append(('equal', original_token_count, original_token_count+1, edited_token_count, edited_token_count+1))
-edit_options.append(None)
-original_token_count += 1
-edited_token_count += 1
-for fragment in flattened_prompt.children:
-if type(fragment) is CrossAttentionControlSubstitute:
-original_prompt.append(fragment.original)
-edited_prompt.append(fragment.edited)
-
-to_replace_token_count = get_tokens_length(model, fragment.original)
-replacement_token_count = get_tokens_length(model, fragment.edited)
-edit_opcodes.append(('replace',
-original_token_count, original_token_count + to_replace_token_count,
-edited_token_count, edited_token_count + replacement_token_count
-))
-original_token_count += to_replace_token_count
-edited_token_count += replacement_token_count
-edit_options.append(fragment.options)
-#elif type(fragment) is CrossAttentionControlAppend:
-# edited_prompt.append(fragment.fragment)
-else:
-# regular fragment
-original_prompt.append(fragment)
-edited_prompt.append(fragment)
-
-count = get_tokens_length(model, [fragment])
-edit_opcodes.append(('equal', original_token_count, original_token_count+count, edited_token_count, edited_token_count+count))
-edit_options.append(None)
-original_token_count += count
-edited_token_count += count
-# end of sequence
-edit_opcodes.append(('equal', original_token_count, original_token_count+1, edited_token_count, edited_token_count+1))
-edit_options.append(None)
-original_token_count += 1
-edited_token_count += 1
-
-original_embeddings, original_tokens = build_embeddings_and_tokens_for_flattened_prompt(model,
-original_prompt,
-log_tokens=log_tokens,
-log_display_label="(.swap originals)")
-# naïvely building a single edited_embeddings like this disregards the effects of changing the absolute location of
-# subsequent tokens when there is >1 edit and earlier edits change the total token count.
-# eg "a cat.swap(smiling dog, s_start=0.5) eating a hotdog.swap(pizza)" - when the 'pizza' edit is active but the
-# 'cat' edit is not, the 'pizza' feature vector will nevertheless be affected by the introduction of the extra
-# token 'smiling' in the inactive 'cat' edit.
-# todo: build multiple edited_embeddings, one for each edit, and pass just the edited fragments through to the CrossAttentionControl functions
-edited_embeddings, edited_tokens = build_embeddings_and_tokens_for_flattened_prompt(model,
-edited_prompt,
-log_tokens=log_tokens,
-log_display_label="(.swap replacements)")
-
-conditioning = original_embeddings
-edited_conditioning = edited_embeddings
-#print('>> got edit_opcodes', edit_opcodes, 'options', edit_options)
-cac_args = cross_attention_control.Arguments(
-edited_conditioning = edited_conditioning,
-edit_opcodes = edit_opcodes,
-edit_options = edit_options
-)
 else:
-conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model,
-flattened_prompt,
+conditioning, _ = _get_embeddings_and_tokens_for_prompt(model,
+parsed_prompt,
 log_tokens=log_tokens,
 log_display_label="(prompt)")
+else:
+raise ValueError(f"parsed_prompt is '{type(parsed_prompt)}' which is not a supported prompt type")
 
-unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model,
+unconditioning, _ = _get_embeddings_and_tokens_for_prompt(model,
 parsed_negative_prompt,
 log_tokens=log_tokens,
 log_display_label="(unconditioning)")
 if isinstance(conditioning, dict):
 # hybrid conditioning is in play
-unconditioning, conditioning = flatten_hybrid_conditioning(unconditioning, conditioning)
+unconditioning, conditioning = _flatten_hybrid_conditioning(unconditioning, conditioning)
 if cac_args is not None:
-print(">> Hybrid conditioning cannot currently be combined with cross attention control. Cross attention control will be ignored.")
+print(
+">> Hybrid conditioning cannot currently be combined with cross attention control. Cross attention control will be ignored.")
 cac_args = None
 
+eos_token_index = 1
+if type(parsed_prompt) is not Blend:
+tokens = get_tokens_for_prompt(model, parsed_prompt)
+eos_token_index = len(tokens)+1
 return (
 unconditioning, conditioning, InvokeAIDiffuserComponent.ExtraConditioningInfo(
+tokens_count_including_eos_bos=eos_token_index + 1,
 cross_attention_control_args=cac_args
 )
 )
 
 
-def build_token_edit_opcodes(original_tokens, edited_tokens):
-original_tokens = original_tokens.cpu().numpy()[0]
-edited_tokens = edited_tokens.cpu().numpy()[0]
+def _get_conditioning_for_cross_attention_control(model, prompt: FlattenedPrompt, log_tokens: bool = True):
+original_prompt = FlattenedPrompt()
+edited_prompt = FlattenedPrompt()
+# for name, a0, a1, b0, b1 in edit_opcodes: only name == 'equal' is currently parsed
+original_token_count = 0
+edited_token_count = 0
+edit_options = []
+edit_opcodes = []
+# beginning of sequence
+edit_opcodes.append(
+('equal', original_token_count, original_token_count + 1, edited_token_count, edited_token_count + 1))
+edit_options.append(None)
+original_token_count += 1
+edited_token_count += 1
+for fragment in prompt.children:
+if type(fragment) is CrossAttentionControlSubstitute:
+original_prompt.append(fragment.original)
+edited_prompt.append(fragment.edited)
+
-return SequenceMatcher(None, original_tokens, edited_tokens).get_opcodes()
+to_replace_token_count = _get_tokens_length(model, fragment.original)
+replacement_token_count = _get_tokens_length(model, fragment.edited)
+edit_opcodes.append(('replace',
+original_token_count, original_token_count + to_replace_token_count,
+edited_token_count, edited_token_count + replacement_token_count
+))
+original_token_count += to_replace_token_count
+edited_token_count += replacement_token_count
+edit_options.append(fragment.options)
+# elif type(fragment) is CrossAttentionControlAppend:
+# edited_prompt.append(fragment.fragment)
+else:
+# regular fragment
+original_prompt.append(fragment)
+edited_prompt.append(fragment)
 
-def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt, log_tokens: bool=False, log_display_label: str=None):
+count = _get_tokens_length(model, [fragment])
+edit_opcodes.append(('equal', original_token_count, original_token_count + count, edited_token_count,
+edited_token_count + count))
+edit_options.append(None)
+original_token_count += count
+edited_token_count += count
+# end of sequence
+edit_opcodes.append(
+('equal', original_token_count, original_token_count + 1, edited_token_count, edited_token_count + 1))
+edit_options.append(None)
+original_token_count += 1
+edited_token_count += 1
+original_embeddings, original_tokens = _get_embeddings_and_tokens_for_prompt(model,
+original_prompt,
+log_tokens=log_tokens,
+log_display_label="(.swap originals)")
+# naïvely building a single edited_embeddings like this disregards the effects of changing the absolute location of
+# subsequent tokens when there is >1 edit and earlier edits change the total token count.
+# eg "a cat.swap(smiling dog, s_start=0.5) eating a hotdog.swap(pizza)" - when the 'pizza' edit is active but the
+# 'cat' edit is not, the 'pizza' feature vector will nevertheless be affected by the introduction of the extra
+# token 'smiling' in the inactive 'cat' edit.
+# todo: build multiple edited_embeddings, one for each edit, and pass just the edited fragments through to the CrossAttentionControl functions
+edited_embeddings, edited_tokens = _get_embeddings_and_tokens_for_prompt(model,
+edited_prompt,
+log_tokens=log_tokens,
+log_display_label="(.swap replacements)")
+conditioning = original_embeddings
+edited_conditioning = edited_embeddings
+# print('>> got edit_opcodes', edit_opcodes, 'options', edit_options)
+cac_args = cross_attention_control.Arguments(
+edited_conditioning=edited_conditioning,
+edit_opcodes=edit_opcodes,
+edit_options=edit_options
+)
+return conditioning, cac_args
+
+
+def _get_conditioning_for_blend(model, blend: Blend, log_tokens: bool = False):
+embeddings_to_blend = None
+for i, flattened_prompt in enumerate(blend.prompts):
+this_embedding, _ = _get_embeddings_and_tokens_for_prompt(model,
+flattened_prompt,
+log_tokens=log_tokens,
+log_display_label=f"(blend part {i + 1}, weight={blend.weights[i]})")
+embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat(
+(embeddings_to_blend, this_embedding))
+conditioning = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0),
+blend.weights,
+normalize=blend.normalize_weights)
+return conditioning
+
+
+def _get_embeddings_and_tokens_for_prompt(model, flattened_prompt: FlattenedPrompt, log_tokens: bool = False,
+log_display_label: str = None):
 if type(flattened_prompt) is not FlattenedPrompt:
 raise Exception(f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead")
 fragments = [x.text for x in flattened_prompt.children]
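The new tokens_count_including_eos_bos bookkeeping is easy to misread, so here is a hedged arithmetic sketch (the token list is invented):

# For a FlattenedPrompt, BOS occupies index 0 and EOS sits right after the prompt tokens:
tokens = ['a</w>', 'photo</w>', 'of</w>', 'a</w>', 'cat</w>']  # from get_tokens_for_prompt()
eos_token_index = len(tokens) + 1                      # 6: BOS at 0, tokens at 1..5, EOS at 6
tokens_count_including_eos_bos = eos_token_index + 1   # 7 positions in total

# For a Blend, eos_token_index keeps its default of 1,
# so tokens_count_including_eos_bos falls back to 2.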
@@ -181,12 +228,14 @@ def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: Fl
 
 return embeddings, tokens
 
-def get_tokens_length(model, fragments: list[Fragment]):
+def _get_tokens_length(model, fragments: list[Fragment]):
 fragment_texts = [x.text for x in fragments]
 tokens = model.cond_stage_model.get_tokens(fragment_texts, include_start_and_end_markers=False)
 return sum([len(x) for x in tokens])
 
-def flatten_hybrid_conditioning(uncond, cond):
+def _flatten_hybrid_conditioning(uncond, cond):
 '''
 This handles the choice between a conditional conditioning
 that is a tensor (used by cross attention) vs one that has additional
@@ -205,4 +254,29 @@ def flatten_hybrid_conditioning(uncond, cond):
 cond_flattened[k] = torch.cat([uncond[k], cond[k]])
 return uncond, cond_flattened
 
 
+def log_tokenization(text, model, display_label=None):
+""" shows how the prompt is tokenized
+# usually tokens have '</w>' to indicate end-of-word,
+# but for readability it has been replaced with ' '
+"""
+
+tokens = model.cond_stage_model.tokenizer.tokenize(text)
+tokenized = ""
+discarded = ""
+usedTokens = 0
+totalTokens = len(tokens)
+for i in range(0, totalTokens):
+token = tokens[i].replace('</w>', ' ')
+# alternate color
+s = (usedTokens % 6) + 1
+if i < model.cond_stage_model.max_length:
+tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
+usedTokens += 1
+else: # over max token length
+discarded = discarded + f"\x1b[0;3{s};40m{token}"
+print(f"\n>> Tokens {display_label or ''} ({usedTokens}):\n{tokenized}\x1b[0m")
+if discarded != "":
+print(
+f">> Tokens Discarded ({totalTokens - usedTokens}):\n{discarded}\x1b[0m"
+)
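log_tokenization() now lives in ldm.invoke.conditioning, which the generate.py hunk earlier references directly. A small usage sketch, assuming an already-loaded model that exposes cond_stage_model (the wrapper function name is hypothetical):

import ldm.invoke.conditioning

def debug_prompt_tokens(model, prompt: str) -> None:
    # prints each token with alternating ANSI colours and reports any tokens
    # discarded beyond the model's maximum sequence length
    ldm.invoke.conditioning.log_tokenization(prompt, model, display_label="(prompt)")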
@@ -14,6 +14,7 @@ import cv2 as cv
 from einops import rearrange, repeat
 from pytorch_lightning import seed_everything
 from ldm.invoke.devices import choose_autocast
+from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
 from ldm.util import rand_perlin_2d
 
 downsampling = 8
@@ -51,9 +52,12 @@ class Generator():
 def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None,
 image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
 safety_checker:dict=None,
+attention_maps_callback = None,
 **kwargs):
 scope = choose_autocast(self.precision)
 self.safety_checker = safety_checker
+attention_maps_images = []
+attention_maps_callback = lambda saver: attention_maps_images.append(saver.get_stacked_maps_image())
 make_image = self.get_make_image(
 prompt,
 sampler = sampler,
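The wiring above replaces any caller-supplied attention_maps_callback with a local collector, and the last collected image is what reaches image_callback further down. An illustrative sketch of that pattern, assuming only that the saver exposes get_stacked_maps_image() as used above:

from typing import Callable, List, Tuple
from PIL import Image

def make_attention_collector() -> Tuple[List[Image.Image], Callable]:
    # mirrors the closure built inside Generator.generate()
    attention_maps_images: List[Image.Image] = []

    def attention_maps_callback(saver) -> None:
        # the sampler invokes this once per generated image
        attention_maps_images.append(saver.get_stacked_maps_image())

    return attention_maps_images, attention_maps_callback

# per iteration the generator then calls:
# image_callback(image, seed, first_seed=first_seed,
#                attention_maps_image=attention_maps_images[-1])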
@@ -63,6 +67,7 @@ class Generator():
 step_callback = step_callback,
 threshold = threshold,
 perlin = perlin,
+attention_maps_callback = attention_maps_callback,
 **kwargs
 )
 results = []
@@ -98,12 +103,12 @@ class Generator():
 results.append([image, seed])
 
 if image_callback is not None:
-image_callback(image, seed, first_seed=first_seed)
+image_callback(image, seed, first_seed=first_seed, attention_maps_image=attention_maps_images[-1])
 
 seed = self.new_seed()
 
 return results
 
 def sample_to_image(self,samples)->Image.Image:
 """
 Given samples returned from a sampler, converts
@@ -166,12 +171,12 @@ class Generator():
 blurred_init_mask = pil_init_mask
 
 multiplied_blurred_init_mask = ImageChops.multiply(blurred_init_mask, self.pil_image.split()[-1])
 
 # Paste original on color-corrected generation (using blurred mask)
 matched_result.paste(init_image, (0,0), mask = multiplied_blurred_init_mask)
 return matched_result
 
 
 
 def sample_to_lowres_estimated_image(self,samples):
 # origingally adapted from code by @erucipe and @keturn here:
@@ -219,11 +224,11 @@ class Generator():
 (txt2img) or from the latent image (img2img, inpaint)
 """
 raise NotImplementedError("get_noise() must be implemented in a descendent class")
 
 def get_perlin_noise(self,width,height):
 fixdevice = 'cpu' if (self.model.device.type == 'mps') else self.model.device
 return torch.stack([rand_perlin_2d((height, width), (8, 8), device = self.model.device).to(fixdevice) for _ in range(self.latent_channels)], dim=0).to(self.model.device)
 
 def new_seed(self):
 self.seed = random.randrange(0, np.iinfo(np.uint32).max)
 return self.seed
@@ -325,4 +330,4 @@ class Generator():
 os.makedirs(dirname, exist_ok=True)
 image.save(filepath,'PNG')
 
 
@@ -14,7 +14,9 @@ class Txt2Img(Generator):
 
 @torch.no_grad()
 def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
-conditioning,width,height,step_callback=None,threshold=0.0,perlin=0.0,**kwargs):
+conditioning,width,height,step_callback=None,threshold=0.0,perlin=0.0,
+attention_maps_callback=None,
+**kwargs):
 """
 Returns a function returning an image derived from the prompt and the initial image
 Return value depends on the seed at the time you call it
@@ -33,7 +35,7 @@ class Txt2Img(Generator):
 
 if self.free_gpu_mem and self.model.model.device != self.model.device:
 self.model.model.to(self.model.device)
 
 sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False)
 
 samples, _ = sampler.sample(
@@ -49,6 +51,7 @@ class Txt2Img(Generator):
 eta = ddim_eta,
 img_callback = step_callback,
 threshold = threshold,
+attention_maps_callback = attention_maps_callback,
 )
 
 if self.free_gpu_mem:
@@ -3,7 +3,7 @@ from typing import Union, Optional
 import re
 import pyparsing as pp
 '''
 This module parses prompt strings and produces tree-like structures that can be used generate and control the conditioning tensors.
 weighted subprompts.
 
 Useful class exports:
@@ -69,6 +69,12 @@ class FlattenedPrompt():
 return len(self.children) == 0 or \
 (len(self.children) == 1 and len(self.children[0].text) == 0)
 
+@property
+def wants_cross_attention_control(self):
+return any(
+[issubclass(type(x), CrossAttentionControlledFragment) for x in self.children]
+)
+
 def __repr__(self):
 return f"FlattenedPrompt:{self.children}"
 def __eq__(self, other):
@ -240,6 +246,12 @@ class Blend():
|
|||||||
self.weights = weights
|
self.weights = weights
|
||||||
self.normalize_weights = normalize_weights
|
self.normalize_weights = normalize_weights
|
||||||
|
|
||||||
|
@property
|
||||||
|
def wants_cross_attention_control(self):
|
||||||
|
# blends cannot cross-attention control
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"Blend:{self.prompts} | weights {' ' if self.normalize_weights else '(non-normalized) '}{self.weights}"
|
return f"Blend:{self.prompts} | weights {' ' if self.normalize_weights else '(non-normalized) '}{self.weights}"
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
@ -277,8 +289,8 @@ class PromptParser():
|
|||||||
|
|
||||||
return self.flatten(root[0])
|
return self.flatten(root[0])
|
||||||
|
|
||||||
def parse_legacy_blend(self, text: str) -> Optional[Blend]:
|
def parse_legacy_blend(self, text: str, skip_normalize: bool) -> Optional[Blend]:
|
||||||
weighted_subprompts = split_weighted_subprompts(text, skip_normalize=False)
|
weighted_subprompts = split_weighted_subprompts(text, skip_normalize=skip_normalize)
|
||||||
if len(weighted_subprompts) <= 1:
|
if len(weighted_subprompts) <= 1:
|
||||||
return None
|
return None
|
||||||
strings = [x[0] for x in weighted_subprompts]
|
strings = [x[0] for x in weighted_subprompts]
|
||||||
@ -287,7 +299,7 @@ class PromptParser():
|
|||||||
parsed_conjunctions = [self.parse_conjunction(x) for x in strings]
|
parsed_conjunctions = [self.parse_conjunction(x) for x in strings]
|
||||||
flattened_prompts = [x.prompts[0] for x in parsed_conjunctions]
|
flattened_prompts = [x.prompts[0] for x in parsed_conjunctions]
|
||||||
|
|
||||||
return Blend(prompts=flattened_prompts, weights=weights, normalize_weights=True)
|
return Blend(prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize)
|
||||||
|
|
||||||
|
|
||||||
def flatten(self, root: Conjunction, verbose = False) -> Conjunction:
|
def flatten(self, root: Conjunction, verbose = False) -> Conjunction:
|
||||||
@ -641,27 +653,3 @@ def split_weighted_subprompts(text, skip_normalize=False)->list:
|
|||||||
return [(x[0], equal_weight) for x in parsed_prompts]
|
return [(x[0], equal_weight) for x in parsed_prompts]
|
||||||
return [(x[0], x[1] / weight_sum) for x in parsed_prompts]
|
return [(x[0], x[1] / weight_sum) for x in parsed_prompts]
|
||||||
|
|
||||||
|
|
||||||
# shows how the prompt is tokenized
|
|
||||||
# usually tokens have '</w>' to indicate end-of-word,
|
|
||||||
# but for readability it has been replaced with ' '
|
|
||||||
def log_tokenization(text, model, display_label=None):
|
|
||||||
tokens = model.cond_stage_model.tokenizer.tokenize(text)
|
|
||||||
tokenized = ""
|
|
||||||
discarded = ""
|
|
||||||
usedTokens = 0
|
|
||||||
totalTokens = len(tokens)
|
|
||||||
for i in range(0, totalTokens):
|
|
||||||
token = tokens[i].replace('</w>', ' ')
|
|
||||||
# alternate color
|
|
||||||
s = (usedTokens % 6) + 1
|
|
||||||
if i < model.cond_stage_model.max_length:
|
|
||||||
tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
|
|
||||||
usedTokens += 1
|
|
||||||
else: # over max token length
|
|
||||||
discarded = discarded + f"\x1b[0;3{s};40m{token}"
|
|
||||||
print(f"\n>> Tokens {display_label or ''} ({usedTokens}):\n{tokenized}\x1b[0m")
|
|
||||||
if discarded != "":
|
|
||||||
print(
|
|
||||||
f">> Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m"
|
|
||||||
)
|
|
||||||
|
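Illustrative note (not part of the commit): a minimal sketch of how the reworked legacy-blend path above might be exercised. Only the PromptParser, Blend and split_weighted_subprompts names visible in this hunk are taken from the diff; the prompt text, the no-argument PromptParser() construction and the printed values are assumptions.

# Hypothetical usage sketch for the parse_legacy_blend() change above.
from ldm.invoke.prompt_parser import PromptParser

parser = PromptParser()
# legacy weighted-subprompt syntax: "<subprompt>:<weight> <subprompt>:<weight> ..."
blend = parser.parse_legacy_blend("a cat playing with a ball:1 a dog:0.25", skip_normalize=False)
if blend is not None:
    print(blend.normalize_weights)              # True, because skip_normalize=False
    print(blend.wants_cross_attention_control)  # always False for a Blend
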
@@ -53,7 +53,6 @@ COMMANDS = (
    '--codeformer_fidelity','-cf',
    '--upscale','-U',
    '-save_orig','--save_original',
-    '--skip_normalize','-x',
    '--log_tokenization','-t',
    '--hires_fix',
    '--inpaint_replace','-r',
@@ -117,19 +116,19 @@ class Completer(object):
            # extensions defined, so go directly into path completion mode
            if self.extensions is not None:
                self.matches = self._path_completions(text, state, self.extensions)

            # looking for an image file
            elif re.search(path_regexp,buffer):
                do_shortcut = re.search('^'+'|'.join(IMG_FILE_COMMANDS),buffer)
                self.matches = self._path_completions(text, state, IMG_EXTENSIONS,shortcut_ok=do_shortcut)

            # looking for a seed
            elif re.search('(-S\s*|--seed[=\s])\d*$',buffer):
                self.matches= self._seed_completions(text,state)

            elif re.search('<[\w-]*$',buffer):
                self.matches= self._concept_completions(text,state)

            # looking for a model
            elif re.match('^'+'|'.join(MODEL_COMMANDS),buffer):
                self.matches= self._model_completions(text, state)
@@ -227,7 +226,7 @@ class Completer(object):
        if h_len < 1:
            print('<empty history>')
            return

        for i in range(0,h_len):
            line = self.get_history_item(i+1)
            if match and match not in line:
@@ -367,7 +366,7 @@ class DummyCompleter(Completer):
    def __init__(self,options):
        super().__init__(options)
        self.history = list()

    def add_history(self,line):
        self.history.append(line)

@@ -1,12 +1,14 @@
import enum
-from typing import Optional
+import math
+from typing import Optional, Callable

+import psutil
import torch
+from torch import nn

# adapted from bloc97's CrossAttentionControl colab
# https://github.com/bloc97/CrossAttentionControl


class Arguments:
    def __init__(self, edited_conditioning: torch.Tensor, edit_opcodes: list[tuple], edit_options: dict):
        """
@@ -63,9 +65,13 @@ class Context:
        self.clear_requests(cleanup=True)

    def register_cross_attention_modules(self, model):
-        for name,module in get_attention_modules(model, CrossAttentionType.SELF):
+        for name,module in get_cross_attention_modules(model, CrossAttentionType.SELF):
+            if name in self.self_cross_attention_module_identifiers:
+                assert False, f"name {name} cannot appear more than once"
            self.self_cross_attention_module_identifiers.append(name)
-        for name,module in get_attention_modules(model, CrossAttentionType.TOKENS):
+        for name,module in get_cross_attention_modules(model, CrossAttentionType.TOKENS):
+            if name in self.tokens_cross_attention_module_identifiers:
+                assert False, f"name {name} cannot appear more than once"
            self.tokens_cross_attention_module_identifiers.append(name)

    def request_save_attention_maps(self, cross_attention_type: CrossAttentionType):
@@ -166,6 +172,135 @@ class Context:
                map_dict[offset] = slice.to('cpu')


+class InvokeAICrossAttentionMixin:
+    """
+    Enable InvokeAI-flavoured CrossAttention calculation, which does aggressive low-memory slicing and calls
+    through both to an attention_slice_wrangler and a slicing_strategy_getter for custom attention map wrangling
+    and dymamic slicing strategy selection.
+    """
+    def __init__(self):
+        self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
+        self.attention_slice_wrangler = None
+        self.slicing_strategy_getter = None
+        self.attention_slice_calculated_callback = None
+
+    def set_attention_slice_wrangler(self, wrangler: Optional[Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]]):
+        '''
+        Set custom attention calculator to be called when attention is calculated
+        :param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size),
+            which returns either the suggested_attention_slice or an adjusted equivalent.
+            `module` is the current CrossAttention module for which the callback is being invoked.
+            `suggested_attention_slice` is the default-calculated attention slice
+            `dim` is -1 if the attenion map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
+                If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length.
+
+        Pass None to use the default attention calculation.
+        :return:
+        '''
+        self.attention_slice_wrangler = wrangler
+
+    def set_slicing_strategy_getter(self, getter: Optional[Callable[[nn.Module], tuple[int,int]]]):
+        self.slicing_strategy_getter = getter
+
+    def set_attention_slice_calculated_callback(self, callback: Optional[Callable[[torch.Tensor], None]]):
+        self.attention_slice_calculated_callback = callback
+
+    def einsum_lowest_level(self, query, key, value, dim, offset, slice_size):
+        # calculate attention scores
+        #attention_scores = torch.einsum('b i d, b j d -> b i j', q, k)
+        attention_scores = torch.baddbmm(
+            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+            query,
+            key.transpose(-1, -2),
+            beta=0,
+            alpha=self.scale,
+        )

+        # calculate attention slice by taking the best scores for each latent pixel
+        default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
+        attention_slice_wrangler = self.attention_slice_wrangler
+        if attention_slice_wrangler is not None:
+            attention_slice = attention_slice_wrangler(self, default_attention_slice, dim, offset, slice_size)
+        else:
+            attention_slice = default_attention_slice
+
+        if self.attention_slice_calculated_callback is not None:
+            self.attention_slice_calculated_callback(attention_slice, dim, offset, slice_size)
+
+        hidden_states = torch.bmm(attention_slice, value)
+        return hidden_states
+
+    def einsum_op_slice_dim0(self, q, k, v, slice_size):
+        r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+        for i in range(0, q.shape[0], slice_size):
+            end = i + slice_size
+            r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size)
+        return r
+
+    def einsum_op_slice_dim1(self, q, k, v, slice_size):
+        r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+        for i in range(0, q.shape[1], slice_size):
+            end = i + slice_size
+            r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size)
+        return r
+
+    def einsum_op_mps_v1(self, q, k, v):
+        if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
+            return self.einsum_lowest_level(q, k, v, None, None, None)
+        else:
+            slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
+            return self.einsum_op_slice_dim1(q, k, v, slice_size)
+
+    def einsum_op_mps_v2(self, q, k, v):
+        if self.mem_total_gb > 8 and q.shape[1] <= 4096:
+            return self.einsum_lowest_level(q, k, v, None, None, None)
+        else:
+            return self.einsum_op_slice_dim0(q, k, v, 1)
+
+    def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb):
+        size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
+        if size_mb <= max_tensor_mb:
+            return self.einsum_lowest_level(q, k, v, None, None, None)
+        div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
+        if div <= q.shape[0]:
+            return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div)
+        return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1))
+
+    def einsum_op_cuda(self, q, k, v):
+        # check if we already have a slicing strategy (this should only happen during cross-attention controlled generation)
+        slicing_strategy_getter = self.slicing_strategy_getter
+        if slicing_strategy_getter is not None:
+            (dim, slice_size) = slicing_strategy_getter(self)
+            if dim is not None:
+                # print("using saved slicing strategy with dim", dim, "slice size", slice_size)
+                if dim == 0:
+                    return self.einsum_op_slice_dim0(q, k, v, slice_size)
+                elif dim == 1:
+                    return self.einsum_op_slice_dim1(q, k, v, slice_size)
+
+        # fallback for when there is no saved strategy, or saved strategy does not slice
+        mem_free_total = self.cached_mem_free_total or get_mem_free_total(q.device)
+        # Divide factor of safety as there's copying and fragmentation
+        return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
+
+    def get_invokeai_attention_mem_efficient(self, q, k, v):
+        if q.device.type == 'cuda':
+            #print("in get_attention_mem_efficient with q shape", q.shape, ", k shape", k.shape, ", free memory is", get_mem_free_total(q.device))
+            return self.einsum_op_cuda(q, k, v)
+
+        if q.device.type == 'mps' or q.device.type == 'cpu':
+            if self.mem_total_gb >= 32:
+                return self.einsum_op_mps_v1(q, k, v)
+            return self.einsum_op_mps_v2(q, k, v)
+
+        # Smaller slices are faster due to L2/L3/SLC caches.
+        # Tested on i7 with 8MB L3 cache.
+        return self.einsum_op_tensor_mem(q, k, v, 32)
+
+
def remove_cross_attention_control(model):
    remove_attention_function(model)

@@ -187,7 +322,7 @@ def setup_cross_attention_control(model, context: Context):
    # mask=1 means use base prompt attention, mask=0 means use edited prompt attention
    mask = torch.zeros(max_length)
    indices_target = torch.arange(max_length, dtype=torch.long)
-    indices = torch.zeros(max_length, dtype=torch.long)
+    indices = torch.arange(max_length, dtype=torch.long)
    for name, a0, a1, b0, b1 in context.arguments.edit_opcodes:
        if b0 < max_length:
            if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0):
@@ -201,10 +336,23 @@ def setup_cross_attention_control(model, context: Context):
    inject_attention_function(model, context)


-def get_attention_modules(model, which: CrossAttentionType):
+def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
+    cross_attention_class: type = InvokeAICrossAttentionMixin
+    # cross_attention_class: type = InvokeAIDiffusersCrossAttention
    which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2"
-    return [(name,module) for name, module in model.named_modules() if
-            type(module).__name__ == "CrossAttention" and which_attn in name]
+    attention_module_tuples = [(name,module) for name, module in model.named_modules() if
+                               isinstance(module, cross_attention_class) and which_attn in name]
+    cross_attention_modules_in_model_count = len(attention_module_tuples)
+    expected_count = 16
+    if cross_attention_modules_in_model_count != expected_count:
+        # non-fatal error but .swap() won't work.
+        print(f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model " +
+              f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching failed " +
+              f"or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, " +
+              f"and/or update the {expected_count} above to an appropriate number, and/or find and inform someone who knows " +
+              f"what it means. This error is non-fatal, but it is likely that .swap() and attention map display will not " +
+              f"work properly until it is fixed.")
+    return attention_module_tuples


def inject_attention_function(unet, context: Context):
@@ -244,19 +392,52 @@ def inject_attention_function(unet, context: Context):

        return attention_slice

-    for name, module in unet.named_modules():
-        module_name = type(module).__name__
-        if module_name == "CrossAttention":
-            module.identifier = name
+    cross_attention_modules = get_cross_attention_modules(unet, CrossAttentionType.TOKENS) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
+    for identifier, module in cross_attention_modules:
+        module.identifier = identifier
+        try:
            module.set_attention_slice_wrangler(attention_slice_wrangler)
-            module.set_slicing_strategy_getter(lambda module, module_identifier=name: \
-                                                   context.get_slicing_strategy(module_identifier))
+            module.set_slicing_strategy_getter(
+                lambda module: context.get_slicing_strategy(identifier)
+            )
+        except AttributeError as e:
+            if is_attribute_error_about(e, 'set_attention_slice_wrangler'):
+                print(f"TODO: implement set_attention_slice_wrangler for {type(module)}") # TODO
+            else:
+                raise


def remove_attention_function(unet):
-    # clear wrangler callback
-    for name, module in unet.named_modules():
-        module_name = type(module).__name__
-        if module_name == "CrossAttention":
+    cross_attention_modules = get_cross_attention_modules(unet, CrossAttentionType.TOKENS) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
+    for identifier, module in cross_attention_modules:
+        try:
+            # clear wrangler callback
            module.set_attention_slice_wrangler(None)
            module.set_slicing_strategy_getter(None)
+        except AttributeError as e:
+            if is_attribute_error_about(e, 'set_attention_slice_wrangler'):
+                print(f"TODO: implement set_attention_slice_wrangler for {type(module)}")
+            else:
+                raise


+def is_attribute_error_about(error: AttributeError, attribute: str):
+    if hasattr(error, 'name'):  # Python 3.10
+        return error.name == attribute
+    else:  # Python 3.9
+        return attribute in str(error)


+def get_mem_free_total(device):
+    #only on cuda
+    if not torch.cuda.is_available():
+        return None
+    stats = torch.cuda.memory_stats(device)
+    mem_active = stats['active_bytes.all.current']
+    mem_reserved = stats['reserved_bytes.all.current']
+    mem_free_cuda, _ = torch.cuda.mem_get_info(device)
+    mem_free_torch = mem_reserved - mem_active
+    mem_free_total = mem_free_cuda + mem_free_torch
+    return mem_free_total

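Illustrative note (not part of the commit): a minimal sketch of how the mixin hooks added above can be wired up by hand. It assumes `model` is a loaded diffusion UNet whose CrossAttention modules already carry the mixin (as arranged in ldm/modules/attention.py further down); the wrangler and callback bodies are placeholders, not anything the commit itself installs.

# Sketch only: a pass-through wrangler plus a shape-logging callback, attached
# via the hooks defined by InvokeAICrossAttentionMixin above.
from ldm.models.diffusion.cross_attention_control import (
    CrossAttentionType,
    get_cross_attention_modules,
)

def passthrough_wrangler(module, suggested_attention_slice, dim, offset, slice_size):
    # a real wrangler could edit the slice here; this one returns it unchanged
    return suggested_attention_slice

def log_slice(attention_slice, dim, offset, slice_size):
    print(f"attention slice {tuple(attention_slice.shape)} dim={dim} offset={offset}")

for name, module in get_cross_attention_modules(model, CrossAttentionType.TOKENS):
    module.set_attention_slice_wrangler(passthrough_wrangler)
    module.set_attention_slice_calculated_callback(log_slice)
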
ldm/models/diffusion/cross_attention_map_saving.py (new file, 95 lines)
@@ -0,0 +1,95 @@
+import math
+
+import PIL
+import torch
+from torchvision.transforms.functional import resize as tv_resize, InterpolationMode
+
+from ldm.models.diffusion.cross_attention_control import get_cross_attention_modules, CrossAttentionType
+
+
+class AttentionMapSaver():
+
+    def __init__(self, token_ids: range, latents_shape: torch.Size):
+        self.token_ids = token_ids
+        self.latents_shape = latents_shape
+        #self.collated_maps = #torch.zeros([len(token_ids), latents_shape[0], latents_shape[1]])
+        self.collated_maps = {}
+
+    def clear_maps(self):
+        self.collated_maps = {}
+
+    def add_attention_maps(self, maps: torch.Tensor, key: str):
+        """
+        Accumulate the given attention maps and store by summing with existing maps at the passed-in key (if any).
+        :param maps: Attention maps to store. Expected shape [A, (H*W), N] where A is attention heads count, H and W are the map size (fixed per-key) and N is the number of tokens (typically 77).
+        :param key: Storage key. If a map already exists for this key it will be summed with the incoming data. In this case the maps sizes (H and W) should match.
+        :return: None
+        """
+        key_and_size = f'{key}_{maps.shape[1]}'
+
+        # extract desired tokens
+        maps = maps[:, :, self.token_ids]
+
+        # merge attention heads to a single map per token
+        maps = torch.sum(maps, 0)
+
+        # store
+        if key_and_size not in self.collated_maps:
+            self.collated_maps[key_and_size] = torch.zeros_like(maps, device='cpu')
+        self.collated_maps[key_and_size] += maps.cpu()
+
+    def write_maps_to_disk(self, path: str):
+        pil_image = self.get_stacked_maps_image()
+        pil_image.save(path, 'PNG')
+
+    def get_stacked_maps_image(self) -> PIL.Image:
+        """
+        Scale all collected attention maps to the same size, blend them together and return as an image.
+        :return: An image containing a vertical stack of blended attention maps, one for each requested token.
+        """
+        num_tokens = len(self.token_ids)
+        if num_tokens == 0:
+            return None
+
+        latents_height = self.latents_shape[0]
+        latents_width = self.latents_shape[1]
+
+        merged = None
+
+        for key, maps in self.collated_maps.items():
+
+            # maps has shape [(H*W), N] for N tokens
+            # but we want [N, H, W]
+            this_scale_factor = math.sqrt(maps.shape[0] / (latents_width * latents_height))
+            this_maps_height = int(float(latents_height) * this_scale_factor)
+            this_maps_width = int(float(latents_width) * this_scale_factor)
+            # and we need to do some dimension juggling
+            maps = torch.reshape(torch.swapdims(maps, 0, 1), [num_tokens, this_maps_height, this_maps_width])
+
+            # scale to output size if necessary
+            if this_scale_factor != 1:
+                maps = tv_resize(maps, [latents_height, latents_width], InterpolationMode.BICUBIC)
+
+            # normalize
+            maps_min = torch.min(maps)
+            maps_range = torch.max(maps) - maps_min
+            #print(f"map {key} size {[this_maps_width, this_maps_height]} range {[maps_min, maps_min + maps_range]}")
+            maps_normalized = (maps - maps_min) / maps_range
+            # expand to (-0.1, 1.1) and clamp
+            maps_normalized_expanded = maps_normalized * 1.1 - 0.05
+            maps_normalized_expanded_clamped = torch.clamp(maps_normalized_expanded, 0, 1)
+
+            # merge together, producing a vertical stack
+            maps_stacked = torch.reshape(maps_normalized_expanded_clamped, [num_tokens * latents_height, latents_width])
+
+            if merged is None:
+                merged = maps_stacked
+            else:
+                # screen blend
+                merged = 1 - (1 - maps_stacked)*(1 - merged)
+
+        if merged is None:
+            return None
+
+        merged_bytes = merged.mul(0xff).byte()
+        return PIL.Image.fromarray(merged_bytes.numpy(), mode='L')

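Illustrative note (not part of the commit): the new AttentionMapSaver can also be driven directly, which makes the expected tensor shapes easier to see. The head count, map size, token range and output path below are made-up values.

# Sketch: feed synthetic attention maps through the new AttentionMapSaver.
# Shapes follow the add_attention_maps() docstring: [heads, H*W, tokens].
import torch
from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver

saver = AttentionMapSaver(token_ids=range(1, 6), latents_shape=torch.Size([64, 64]))
fake_maps = torch.rand(8, 64 * 64, 77)   # 8 heads, a 64x64 latent map, 77 token slots
saver.add_attention_maps(fake_maps, key='mid')
saver.write_maps_to_disk('/tmp/attention_maps.png')  # vertical stack, one map per requested token
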
@@ -4,6 +4,7 @@ import k_diffusion as K
import torch
from torch import nn

+from .cross_attention_map_saving import AttentionMapSaver
from .sampler import Sampler
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent

@@ -36,6 +37,7 @@ class CFGDenoiser(nn.Module):
        self.invokeai_diffuser = InvokeAIDiffuserComponent(model,
                                                           model_forward_callback=lambda x, sigma, cond: self.inner_model(x, sigma, cond=cond))


    def prepare_to_sample(self, t_enc, **kwargs):

        extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
@@ -106,12 +108,12 @@ class KSampler(Sampler):
        else:
            print(f'>> Ksampler using karras noise schedule (steps < {self.karras_max})')
            self.sigmas = self.karras_sigmas

    # ALERT: We are completely overriding the sample() method in the base class, which
    # means that inpainting will not work. To get this to work we need to be able to
    # modify the inner loop of k_heun, k_lms, etc, as is done in an ugly way
    # in the lstein/k-diffusion branch.

    @torch.no_grad()
    def decode(
        self,
@@ -145,7 +147,7 @@ class KSampler(Sampler):
    @torch.no_grad()
    def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
        return x0

    # Most of these arguments are ignored and are only present for compatibility with
    # other samples
    @torch.no_grad()
@@ -158,6 +160,7 @@ class KSampler(Sampler):
        callback=None,
        normals_sequence=None,
        img_callback=None,
+        attention_maps_callback=None,
        quantize_x0=False,
        eta=0.0,
        mask=None,
@@ -171,7 +174,7 @@ class KSampler(Sampler):
        log_every_t=100,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
-        extra_conditioning_info=None,
+        extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo=None,
        threshold = 0,
        perlin = 0,
        # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
@@ -204,6 +207,12 @@ class KSampler(Sampler):

        model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10))
        model_wrap_cfg.prepare_to_sample(S, extra_conditioning_info=extra_conditioning_info)

+        attention_map_token_ids = range(1, extra_conditioning_info.tokens_count_including_eos_bos - 1)
+        attention_maps_saver = None if attention_maps_callback is None else AttentionMapSaver(token_ids = attention_map_token_ids, latents_shape=x.shape[-2:])
+        if attention_maps_callback is not None:
+            model_wrap_cfg.invokeai_diffuser.setup_attention_map_saving(attention_maps_saver)

        extra_args = {
            'cond': conditioning,
            'uncond': unconditional_conditioning,
@@ -217,6 +226,8 @@ class KSampler(Sampler):
            ),
            None,
        )
+        if attention_maps_callback is not None:
+            attention_maps_callback(attention_maps_saver)
        return sampling_result

    # this code will support inpainting if and when ksampler API modified or
@@ -248,7 +259,7 @@ class KSampler(Sampler):
        # terrible, confusing names here
        steps = self.ddim_num_steps
        t_enc = self.t_enc

        # sigmas is a full steps in length, but t_enc might
        # be less. We start in the middle of the sigma array
        # and work our way to the end after t_enc steps.
@@ -280,7 +291,7 @@ class KSampler(Sampler):
            return x_T + x
        else:
            return x

    def prepare_to_sample(self,t_enc,**kwargs):
        self.t_enc = t_enc
        self.model_wrap = None

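Illustrative note (not part of the commit): the new attention_maps_callback threads through KSampler.sample() roughly as below. Only the parameter names visible in the hunks above are taken from the diff; `sampler`, `steps`, `c`, `uc`, `shape` and `extra_info` are assumed to exist already, and other sample() arguments may differ from what the surrounding sampler API actually expects.

# Sketch: ask the KSampler for attention maps and dump them to disk afterwards.
def on_attention_maps(saver):
    if saver is not None:
        saver.write_maps_to_disk('/tmp/attention_maps.png')

samples, _ = sampler.sample(
    S=steps,
    conditioning=c,
    batch_size=1,
    shape=shape,
    unconditional_guidance_scale=7.5,
    unconditional_conditioning=uc,
    extra_conditioning_info=extra_info,
    attention_maps_callback=on_attention_maps,
)
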
@@ -5,8 +5,8 @@ from typing import Callable, Optional, Union
import torch

from ldm.models.diffusion.cross_attention_control import Arguments, \
-    remove_cross_attention_control, setup_cross_attention_control, Context
-from ldm.modules.attention import get_mem_free_total
+    remove_cross_attention_control, setup_cross_attention_control, Context, get_cross_attention_modules, CrossAttentionType
+from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver


class InvokeAIDiffuserComponent:
@@ -21,7 +21,8 @@ class InvokeAIDiffuserComponent:


    class ExtraConditioningInfo:
-        def __init__(self, cross_attention_control_args: Optional[Arguments]):
+        def __init__(self, tokens_count_including_eos_bos:int, cross_attention_control_args: Optional[Arguments]):
+            self.tokens_count_including_eos_bos = tokens_count_including_eos_bos
            self.cross_attention_control_args = cross_attention_control_args

        @property
@@ -52,7 +53,25 @@ class InvokeAIDiffuserComponent:
        self.cross_attention_control_context = None
        remove_cross_attention_control(self.model)

+    def setup_attention_map_saving(self, saver: AttentionMapSaver):
+        def callback(slice, dim, offset, slice_size, key):
+            if dim is not None:
+                # sliced tokens attention map saving is not implemented
+                return
+            saver.add_attention_maps(slice, key)
+
+        tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
+        for identifier, module in tokens_cross_attention_modules:
+            key = ('down' if identifier.startswith('down') else
+                   'up' if identifier.startswith('up') else
+                   'mid')
+            module.set_attention_slice_calculated_callback(
+                lambda slice, dim, offset, slice_size, key=key: callback(slice, dim, offset, slice_size, key))
+
+    def remove_attention_map_saving(self):
+        tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
+        for _, module in tokens_cross_attention_modules:
+            module.set_attention_slice_calculated_callback(None)

    def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor,
                          unconditioning: Union[torch.Tensor,dict],

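Illustrative note (not part of the commit): the extended ExtraConditioningInfo and the new map-saving hooks compose roughly as below. The token count of 10, the 64x64 latent shape and the `diffuser` variable (an already-constructed InvokeAIDiffuserComponent) are stand-ins, not values taken from the diff.

# Sketch: construct the extended ExtraConditioningInfo and toggle map saving.
import torch
from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent

info = InvokeAIDiffuserComponent.ExtraConditioningInfo(
    tokens_count_including_eos_bos=10,      # stand-in for the real tokenizer count
    cross_attention_control_args=None,      # no prompt-to-prompt editing
)

saver = AttentionMapSaver(token_ids=range(1, info.tokens_count_including_eos_bos - 1),
                          latents_shape=torch.Size([64, 64]))
diffuser.setup_attention_map_saving(saver)
# ... run the denoising loop, then detach the callbacks again ...
diffuser.remove_attention_map_saving()
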
@@ -7,10 +7,9 @@ import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat

+from ldm.models.diffusion.cross_attention_control import InvokeAICrossAttentionMixin
from ldm.modules.diffusionmodules.util import checkpoint

-import psutil

def exists(val):
    return val is not None

@@ -164,9 +163,10 @@ def get_mem_free_total(device):
    return mem_free_total


-class CrossAttention(nn.Module):
+class CrossAttention(nn.Module, InvokeAICrossAttentionMixin):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
        super().__init__()
+        InvokeAICrossAttentionMixin.__init__(self)
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

@@ -182,118 +182,6 @@ class CrossAttention(nn.Module):
            nn.Dropout(dropout)
        )

-        self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
-
-        self.cached_mem_free_total = None
-        self.attention_slice_wrangler = None
-        self.slicing_strategy_getter = None
-
-    def set_attention_slice_wrangler(self, wrangler: Optional[Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]]):
-        '''
-        Set custom attention calculator to be called when attention is calculated
-        :param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size),
-            which returns either the suggested_attention_slice or an adjusted equivalent.
-            `module` is the current CrossAttention module for which the callback is being invoked.
-            `suggested_attention_slice` is the default-calculated attention slice
-            `dim` is -1 if the attenion map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
-                If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length.
-
-        Pass None to use the default attention calculation.
-        :return:
-        '''
-        self.attention_slice_wrangler = wrangler
-
-    def set_slicing_strategy_getter(self, getter: Optional[Callable[[nn.Module], tuple[int,int]]]):
-        self.slicing_strategy_getter = getter
-
-    def cache_free_memory_count(self, device):
-        self.cached_mem_free_total = get_mem_free_total(device)
-        print("free cuda memory: ", self.cached_mem_free_total)
-
-    def clear_cached_free_memory_count(self):
-        self.cached_mem_free_total = None
-
-    def einsum_lowest_level(self, q, k, v, dim, offset, slice_size):
-        # calculate attention scores
-        attention_scores = einsum('b i d, b j d -> b i j', q, k)
-        # calculate attention slice by taking the best scores for each latent pixel
-        default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
-        attention_slice_wrangler = self.attention_slice_wrangler
-        if attention_slice_wrangler is not None:
-            attention_slice = attention_slice_wrangler(self, default_attention_slice, dim, offset, slice_size)
-        else:
-            attention_slice = default_attention_slice
-
-        return einsum('b i j, b j d -> b i d', attention_slice, v)
-
-    def einsum_op_slice_dim0(self, q, k, v, slice_size):
-        r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
-        for i in range(0, q.shape[0], slice_size):
-            end = i + slice_size
-            r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size)
-        return r
-
-    def einsum_op_slice_dim1(self, q, k, v, slice_size):
-        r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
-        for i in range(0, q.shape[1], slice_size):
-            end = i + slice_size
-            r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size)
-        return r
-
-    def einsum_op_mps_v1(self, q, k, v):
-        if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
-            return self.einsum_lowest_level(q, k, v, None, None, None)
-        else:
-            slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
-            return self.einsum_op_slice_dim1(q, k, v, slice_size)
-
-    def einsum_op_mps_v2(self, q, k, v):
-        if self.mem_total_gb > 8 and q.shape[1] <= 4096:
-            return self.einsum_lowest_level(q, k, v, None, None, None)
-        else:
-            return self.einsum_op_slice_dim0(q, k, v, 1)
-
-    def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb):
-        size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
-        if size_mb <= max_tensor_mb:
-            return self.einsum_lowest_level(q, k, v, None, None, None)
-        div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
-        if div <= q.shape[0]:
-            return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div)
-        return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1))
-
-    def einsum_op_cuda(self, q, k, v):
-        # check if we already have a slicing strategy (this should only happen during cross-attention controlled generation)
-        slicing_strategy_getter = self.slicing_strategy_getter
-        if slicing_strategy_getter is not None:
-            (dim, slice_size) = slicing_strategy_getter(self)
-            if dim is not None:
-                # print("using saved slicing strategy with dim", dim, "slice size", slice_size)
-                if dim == 0:
-                    return self.einsum_op_slice_dim0(q, k, v, slice_size)
-                elif dim == 1:
-                    return self.einsum_op_slice_dim1(q, k, v, slice_size)
-
-        # fallback for when there is no saved strategy, or saved strategy does not slice
-        mem_free_total = self.cached_mem_free_total or get_mem_free_total(q.device)
-        # Divide factor of safety as there's copying and fragmentation
-        return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
-
-    def get_attention_mem_efficient(self, q, k, v):
-        if q.device.type == 'cuda':
-            #print("in get_attention_mem_efficient with q shape", q.shape, ", k shape", k.shape, ", free memory is", get_mem_free_total(q.device))
-            return self.einsum_op_cuda(q, k, v)
-
-        if q.device.type == 'mps':
-            if self.mem_total_gb >= 32:
-                return self.einsum_op_mps_v1(q, k, v)
-            return self.einsum_op_mps_v2(q, k, v)
-
-        # Smaller slices are faster due to L2/L3/SLC caches.
-        # Tested on i7 with 8MB L3 cache.
-        return self.einsum_op_tensor_mem(q, k, v, 32)
-
    def forward(self, x, context=None, mask=None):
        h = self.heads

@@ -305,7 +193,11 @@ class CrossAttention(nn.Module):

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

-        r = self.get_attention_mem_efficient(q, k, v)
+        # don't apply scale twice
+        cached_scale = self.scale
+        self.scale = 1
+        r = self.get_invokeai_attention_mem_efficient(q, k, v)
+        self.scale = cached_scale

        hidden_states = rearrange(r, '(b h) n d -> b n (h d)', h=h)
        return self.to_out(hidden_states)

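Illustrative note (not part of the commit): a standalone check of the "don't apply scale twice" comment above. The mixin's baddbmm path applies the attention scale via alpha=, so if the scale were also applied elsewhere in forward() (outside the lines shown in this hunk) it would compound; parking self.scale at 1 during the call avoids that. The shapes below are arbitrary.

# Demonstrate that applying the scale both to q and via baddbmm's alpha= compounds it.
import torch

q = torch.randn(2, 4, 8)
k = torch.randn(2, 4, 8)
scale = 8 ** -0.5

# scale applied once, inside baddbmm (what the mixin's einsum_lowest_level does)
once = torch.baddbmm(torch.empty(2, 4, 4), q, k.transpose(-1, -2), beta=0, alpha=scale)
# scale applied both to q and inside baddbmm: the result carries scale squared
twice = torch.baddbmm(torch.empty(2, 4, 4), q * scale, k.transpose(-1, -2), beta=0, alpha=scale)

assert torch.allclose(twice, once * scale, atol=1e-6)
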