mirror of https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00

commit 63532226a5

Merge remote-tracking branch 'origin/main' into dev/diffusers

# Conflicts:
#   backend/invoke_ai_web_server.py
#   ldm/generate.py
#   ldm/invoke/CLI.py
#   ldm/invoke/generator/base.py
#   ldm/invoke/generator/txt2img.py
#   ldm/models/diffusion/cross_attention_control.py
#   ldm/modules/attention.py
@@ -21,6 +21,7 @@ from backend.modules.get_canvas_generation_mode import (
     get_canvas_generation_mode,
 )
 from backend.modules.parameters import parameters_to_command
+from ldm.generate import Generate
 from ldm.invoke.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash
 from ldm.invoke.generator.diffusers_pipeline import PipelineIntermediateState
 from ldm.invoke.generator.inpaint import infill_methods
@@ -38,7 +39,7 @@ if not os.path.isabs(args.outdir):
 
 
 class InvokeAIWebServer:
-    def __init__(self, generate, gfpgan, codeformer, esrgan) -> None:
+    def __init__(self, generate: Generate, gfpgan, codeformer, esrgan) -> None:
         self.host = args.host
         self.port = args.port
 
@@ -906,16 +907,13 @@ class InvokeAIWebServer:
                     },
                 )
 
+
                 if generation_parameters["progress_latents"]:
                     image = self.generate.sample_to_lowres_estimated_image(sample)
                     (width, height) = image.size
                     width *= 8
                     height *= 8
-                    buffered = io.BytesIO()
-                    image.save(buffered, format="PNG")
-                    img_base64 = "data:image/png;base64," + base64.b64encode(
-                        buffered.getvalue()
-                    ).decode("UTF-8")
+                    img_base64 = image_to_dataURL(image)
                     self.socketio.emit(
                         "intermediateResult",
                         {
@@ -933,7 +931,7 @@ class InvokeAIWebServer:
             self.socketio.emit("progressUpdate", progress.to_formatted_dict())
             eventlet.sleep(0)
 
-        def image_done(image, seed, first_seed):
+        def image_done(image, seed, first_seed, attention_maps_image=None):
             if self.canceled.is_set():
                 raise CanceledException
 
@@ -1095,6 +1093,12 @@ class InvokeAIWebServer:
             self.socketio.emit("progressUpdate", progress.to_formatted_dict())
             eventlet.sleep(0)
 
+            parsed_prompt, _ = get_prompt_structure(generation_parameters["prompt"])
+            tokens = None if type(parsed_prompt) is Blend else \
+                get_tokens_for_prompt(self.generate.model, parsed_prompt)
+            attention_maps_image_base64_url = None if attention_maps_image is None \
+                else image_to_dataURL(attention_maps_image)
+
             self.socketio.emit(
                 "generationResult",
                 {
@@ -1107,6 +1111,8 @@ class InvokeAIWebServer:
                     "height": height,
                     "boundingBox": original_bounding_box,
                     "generationMode": generation_parameters["generation_mode"],
+                    "attentionMaps": attention_maps_image_base64_url,
+                    "tokens": tokens,
                 },
             )
             eventlet.sleep(0)
@@ -1118,7 +1124,7 @@ class InvokeAIWebServer:
             self.generate.prompt2image(
                 **generation_parameters,
                 step_callback=image_progress,
-                image_callback=image_done,
+                image_callback=image_done
             )
 
         except KeyboardInterrupt:
@@ -1565,6 +1571,19 @@ def dataURL_to_image(dataURL: str) -> ImageType:
     )
     return image
 
+"""
+Converts an image into a base64 image dataURL.
+"""
+
+def image_to_dataURL(image: ImageType) -> str:
+    buffered = io.BytesIO()
+    image.save(buffered, format="PNG")
+    image_base64 = "data:image/png;base64," + base64.b64encode(
+        buffered.getvalue()
+    ).decode("UTF-8")
+    return image_base64
+
+
 
 """
 Converts a base64 image dataURL into bytes.
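The hunks above route intermediate previews and the new attention-map payloads through a single image_to_dataURL() helper instead of inlining the base64 encoding. As a rough, hedged sketch of how that helper is expected to behave (the PIL test image and the import path are assumptions, not part of the commit):

# Minimal usage sketch, assuming image_to_dataURL() is importable from the web server module.
from PIL import Image
from backend.invoke_ai_web_server import image_to_dataURL  # assumed import path

preview = Image.new("RGB", (64, 64), color="gray")  # hypothetical stand-in for a decoded latent preview
data_url = image_to_dataURL(preview)
assert data_url.startswith("data:image/png;base64,")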
@@ -16,7 +16,6 @@ import numpy as np
 import skimage
 import torch
 import transformers
-from PIL import Image, ImageOps
 from diffusers import HeunDiscreteScheduler
 from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.schedulers.scheduling_ddim import DDIMScheduler
@@ -27,6 +26,7 @@ from diffusers.schedulers.scheduling_ipndm import IPNDMScheduler
 from diffusers.schedulers.scheduling_lms_discrete import LMSDiscreteScheduler
 from diffusers.schedulers.scheduling_pndm import PNDMScheduler
 from omegaconf import OmegaConf
+from PIL import Image, ImageOps
 from pytorch_lightning import seed_everything, logging
 
 from ldm.invoke.args import metadata_from_png
@@ -461,7 +461,7 @@ class Generate:
         try:
             uc, c, extra_conditioning_info = get_uc_and_c_and_ec(
                 prompt, model =self.model,
-                skip_normalize=skip_normalize,
+                skip_normalize_legacy_blend=skip_normalize,
                 log_tokens =self.log_tokenization
             )
 
@@ -613,8 +613,8 @@ class Generate:
         # todo: cross-attention control
         uc, c, extra_conditioning_info = get_uc_and_c_and_ec(
             prompt, model =self.model,
-            skip_normalize=opt.skip_normalize,
-            log_tokens =opt.log_tokenization
+            skip_normalize_legacy_blend=opt.skip_normalize,
+            log_tokens =ldm.invoke.conditioning.log_tokenization
         )
 
         if tool in ('gfpgan','codeformer','upscale'):
@@ -8,6 +8,7 @@ import time
 import traceback
 import yaml
 
+from ldm.generate import Generate
 from ldm.invoke.globals import Globals
 from ldm.invoke.prompt_parser import PromptParser
 from ldm.invoke.readline import get_completer, Completer
@@ -282,7 +283,7 @@ def main_loop(gen, opt):
         prefix = file_writer.unique_prefix()
         step_callback = make_step_callback(gen, opt, prefix) if opt.save_intermediates > 0 else None
 
-        def image_writer(image, seed, upscaled=False, first_seed=None, use_prefix=None, prompt_in=None):
+        def image_writer(image, seed, upscaled=False, first_seed=None, use_prefix=None, prompt_in=None, attention_maps_image=None):
             # note the seed is the seed of the current image
             # the first_seed is the original seed that noise is added to
             # when the -v switch is used to generate variations
@@ -790,7 +791,7 @@ def get_next_command(infile=None) -> str: # command string
         print(f'#{command}')
     return command
 
-def invoke_ai_web_server_loop(gen, gfpgan, codeformer, esrgan):
+def invoke_ai_web_server_loop(gen: Generate, gfpgan, codeformer, esrgan):
     print('\n* --web was specified, starting web server...')
     from backend.invoke_ai_web_server import InvokeAIWebServer
     # Change working directory to the stable-diffusion directory
@@ -7,20 +7,46 @@ get_uc_and_c_and_ec() get the conditioned and unconditioned latent, an
 
 '''
 import re
-from difflib import SequenceMatcher
 from typing import Union
 
 import torch
 
 from .prompt_parser import PromptParser, Blend, FlattenedPrompt, \
-    CrossAttentionControlledFragment, CrossAttentionControlSubstitute, Fragment, log_tokenization
+    CrossAttentionControlledFragment, CrossAttentionControlSubstitute, Fragment
 from ..models.diffusion import cross_attention_control
 from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
 from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder
 
 
-def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_normalize=False):
+def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False):
+    prompt, negative_prompt = get_prompt_structure(prompt_string,
+                                                   skip_normalize_legacy_blend=skip_normalize_legacy_blend)
+    conditioning = _get_conditioning_for_prompt(prompt, negative_prompt, model, log_tokens)
+
+    return conditioning
+
+
+def get_prompt_structure(prompt_string, skip_normalize_legacy_blend: bool = False) -> (
+        Union[FlattenedPrompt, Blend], FlattenedPrompt):
+    """
+    parse the passed-in prompt string and return tuple (positive_prompt, negative_prompt)
+    """
+    prompt, negative_prompt = _parse_prompt_string(prompt_string,
+                                                   skip_normalize_legacy_blend=skip_normalize_legacy_blend)
+    return prompt, negative_prompt
+
+
+def get_tokens_for_prompt(model, parsed_prompt: FlattenedPrompt) -> [str]:
+    text_fragments = [x.text if type(x) is Fragment else
+                      (" ".join([f.text for f in x.original]) if type(x) is CrossAttentionControlSubstitute else
+                       str(x))
+                      for x in parsed_prompt.children]
+    text = " ".join(text_fragments)
+    tokens = model.cond_stage_model.tokenizer.tokenize(text)
+    return tokens
+
+
+def _parse_prompt_string(prompt_string_uncleaned, skip_normalize_legacy_blend=False) -> Union[FlattenedPrompt, Blend]:
     # Extract Unconditioned Words From Prompt
     unconditioned_words = ''
     unconditional_regex = r'\[(.*?)\]'
@@ -39,7 +65,7 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
     pp = PromptParser()
 
     parsed_prompt: Union[FlattenedPrompt, Blend] = None
-    legacy_blend: Blend = pp.parse_legacy_blend(prompt_string_cleaned)
+    legacy_blend: Blend = pp.parse_legacy_blend(prompt_string_cleaned, skip_normalize_legacy_blend)
     if legacy_blend is not None:
         parsed_prompt = legacy_blend
     else:
@@ -47,129 +73,150 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
         parsed_prompt = pp.parse_conjunction(prompt_string_cleaned).prompts[0]
 
     parsed_negative_prompt: FlattenedPrompt = pp.parse_conjunction(unconditioned_words).prompts[0]
+    return parsed_prompt, parsed_negative_prompt
+
+
+def _get_conditioning_for_prompt(parsed_prompt: Union[Blend, FlattenedPrompt], parsed_negative_prompt: FlattenedPrompt,
+                                 model, log_tokens=False) \
+        -> tuple[torch.Tensor, torch.Tensor, InvokeAIDiffuserComponent.ExtraConditioningInfo]:
+    """
+    Process prompt structure and tokens, and return (conditioning, unconditioning, extra_conditioning_info)
+    """
+
     if log_tokens:
         print(f">> Parsed prompt to {parsed_prompt}")
         print(f">> Parsed negative prompt to {parsed_negative_prompt}")
 
     conditioning = None
-    cac_args:cross_attention_control.Arguments = None
+    cac_args: cross_attention_control.Arguments = None
 
     if type(parsed_prompt) is Blend:
-        blend: Blend = parsed_prompt
-        embeddings_to_blend = None
-        for i,flattened_prompt in enumerate(blend.prompts):
-            this_embedding, _ = build_embeddings_and_tokens_for_flattened_prompt(model,
-                                                                                 flattened_prompt,
-                                                                                 log_tokens=log_tokens,
-                                                                                 log_display_label=f"(blend part {i+1}, weight={blend.weights[i]})" )
-            embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat(
-                (embeddings_to_blend, this_embedding))
-        conditioning = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0),
-                                                                          blend.weights,
-                                                                          normalize=blend.normalize_weights)
-    else:
-        flattened_prompt: FlattenedPrompt = parsed_prompt
-        wants_cross_attention_control = type(flattened_prompt) is not Blend \
-            and any([issubclass(type(x), CrossAttentionControlledFragment) for x in flattened_prompt.children])
-        if wants_cross_attention_control:
-            original_prompt = FlattenedPrompt()
-            edited_prompt = FlattenedPrompt()
-            # for name, a0, a1, b0, b1 in edit_opcodes: only name == 'equal' is currently parsed
-            original_token_count = 0
-            edited_token_count = 0
-            edit_options = []
-            edit_opcodes = []
-            # beginning of sequence
-            edit_opcodes.append(('equal', original_token_count, original_token_count+1, edited_token_count, edited_token_count+1))
-            edit_options.append(None)
-            original_token_count += 1
-            edited_token_count += 1
-            for fragment in flattened_prompt.children:
-                if type(fragment) is CrossAttentionControlSubstitute:
-                    original_prompt.append(fragment.original)
-                    edited_prompt.append(fragment.edited)
-
-                    to_replace_token_count = get_tokens_length(model, fragment.original)
-                    replacement_token_count = get_tokens_length(model, fragment.edited)
-                    edit_opcodes.append(('replace',
-                                         original_token_count, original_token_count + to_replace_token_count,
-                                         edited_token_count, edited_token_count + replacement_token_count
-                                         ))
-                    original_token_count += to_replace_token_count
-                    edited_token_count += replacement_token_count
-                    edit_options.append(fragment.options)
-                #elif type(fragment) is CrossAttentionControlAppend:
-                #    edited_prompt.append(fragment.fragment)
-                else:
-                    # regular fragment
-                    original_prompt.append(fragment)
-                    edited_prompt.append(fragment)
-
-                    count = get_tokens_length(model, [fragment])
-                    edit_opcodes.append(('equal', original_token_count, original_token_count+count, edited_token_count, edited_token_count+count))
-                    edit_options.append(None)
-                    original_token_count += count
-                    edited_token_count += count
-            # end of sequence
-            edit_opcodes.append(('equal', original_token_count, original_token_count+1, edited_token_count, edited_token_count+1))
-            edit_options.append(None)
-            original_token_count += 1
-            edited_token_count += 1
-
-            original_embeddings, original_tokens = build_embeddings_and_tokens_for_flattened_prompt(model,
-                                                                                                    original_prompt,
-                                                                                                    log_tokens=log_tokens,
-                                                                                                    log_display_label="(.swap originals)")
-            # naïvely building a single edited_embeddings like this disregards the effects of changing the absolute location of
-            # subsequent tokens when there is >1 edit and earlier edits change the total token count.
-            # eg "a cat.swap(smiling dog, s_start=0.5) eating a hotdog.swap(pizza)" - when the 'pizza' edit is active but the
-            # 'cat' edit is not, the 'pizza' feature vector will nevertheless be affected by the introduction of the extra
-            # token 'smiling' in the inactive 'cat' edit.
-            # todo: build multiple edited_embeddings, one for each edit, and pass just the edited fragments through to the CrossAttentionControl functions
-            edited_embeddings, edited_tokens = build_embeddings_and_tokens_for_flattened_prompt(model,
-                                                                                                edited_prompt,
-                                                                                                log_tokens=log_tokens,
-                                                                                                log_display_label="(.swap replacements)")
-
-            conditioning = original_embeddings
-            edited_conditioning = edited_embeddings
-            #print('>> got edit_opcodes', edit_opcodes, 'options', edit_options)
-            cac_args = cross_attention_control.Arguments(
-                edited_conditioning = edited_conditioning,
-                edit_opcodes = edit_opcodes,
-                edit_options = edit_options
-            )
+        conditioning = _get_conditioning_for_blend(model, parsed_prompt, log_tokens)
+    elif type(parsed_prompt) is FlattenedPrompt:
+        if parsed_prompt.wants_cross_attention_control:
+            conditioning, cac_args = _get_conditioning_for_cross_attention_control(model, parsed_prompt, log_tokens)
         else:
-            conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model,
-                                                                               flattened_prompt,
+            conditioning, _ = _get_embeddings_and_tokens_for_prompt(model,
+                                                                    parsed_prompt,
                                                                     log_tokens=log_tokens,
                                                                     log_display_label="(prompt)")
+    else:
+        raise ValueError(f"parsed_prompt is '{type(parsed_prompt)}' which is not a supported prompt type")
 
-    unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model,
+    unconditioning, _ = _get_embeddings_and_tokens_for_prompt(model,
                                                               parsed_negative_prompt,
                                                               log_tokens=log_tokens,
                                                               log_display_label="(unconditioning)")
     if isinstance(conditioning, dict):
         # hybrid conditioning is in play
-        unconditioning, conditioning = flatten_hybrid_conditioning(unconditioning, conditioning)
+        unconditioning, conditioning = _flatten_hybrid_conditioning(unconditioning, conditioning)
         if cac_args is not None:
-            print(">> Hybrid conditioning cannot currently be combined with cross attention control. Cross attention control will be ignored.")
+            print(
+                ">> Hybrid conditioning cannot currently be combined with cross attention control. Cross attention control will be ignored.")
             cac_args = None
 
+    eos_token_index = 1
+    if type(parsed_prompt) is not Blend:
+        tokens = get_tokens_for_prompt(model, parsed_prompt)
+        eos_token_index = len(tokens)+1
     return (
         unconditioning, conditioning, InvokeAIDiffuserComponent.ExtraConditioningInfo(
+            tokens_count_including_eos_bos=eos_token_index + 1,
             cross_attention_control_args=cac_args
         )
     )
 
 
-def build_token_edit_opcodes(original_tokens, edited_tokens):
-    original_tokens = original_tokens.cpu().numpy()[0]
-    edited_tokens = edited_tokens.cpu().numpy()[0]
+def _get_conditioning_for_cross_attention_control(model, prompt: FlattenedPrompt, log_tokens: bool = True):
+    original_prompt = FlattenedPrompt()
+    edited_prompt = FlattenedPrompt()
+    # for name, a0, a1, b0, b1 in edit_opcodes: only name == 'equal' is currently parsed
+    original_token_count = 0
+    edited_token_count = 0
+    edit_options = []
+    edit_opcodes = []
+    # beginning of sequence
+    edit_opcodes.append(
+        ('equal', original_token_count, original_token_count + 1, edited_token_count, edited_token_count + 1))
+    edit_options.append(None)
+    original_token_count += 1
+    edited_token_count += 1
+    for fragment in prompt.children:
+        if type(fragment) is CrossAttentionControlSubstitute:
+            original_prompt.append(fragment.original)
+            edited_prompt.append(fragment.edited)
 
-    return SequenceMatcher(None, original_tokens, edited_tokens).get_opcodes()
+            to_replace_token_count = _get_tokens_length(model, fragment.original)
+            replacement_token_count = _get_tokens_length(model, fragment.edited)
+            edit_opcodes.append(('replace',
+                                 original_token_count, original_token_count + to_replace_token_count,
+                                 edited_token_count, edited_token_count + replacement_token_count
+                                 ))
+            original_token_count += to_replace_token_count
+            edited_token_count += replacement_token_count
+            edit_options.append(fragment.options)
+        # elif type(fragment) is CrossAttentionControlAppend:
+        #    edited_prompt.append(fragment.fragment)
+        else:
+            # regular fragment
+            original_prompt.append(fragment)
+            edited_prompt.append(fragment)
 
-
-def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt, log_tokens: bool=False, log_display_label: str=None):
+            count = _get_tokens_length(model, [fragment])
+            edit_opcodes.append(('equal', original_token_count, original_token_count + count, edited_token_count,
+                                 edited_token_count + count))
+            edit_options.append(None)
+            original_token_count += count
+            edited_token_count += count
+    # end of sequence
+    edit_opcodes.append(
+        ('equal', original_token_count, original_token_count + 1, edited_token_count, edited_token_count + 1))
+    edit_options.append(None)
+    original_token_count += 1
+    edited_token_count += 1
+    original_embeddings, original_tokens = _get_embeddings_and_tokens_for_prompt(model,
+                                                                                 original_prompt,
+                                                                                 log_tokens=log_tokens,
+                                                                                 log_display_label="(.swap originals)")
+    # naïvely building a single edited_embeddings like this disregards the effects of changing the absolute location of
+    # subsequent tokens when there is >1 edit and earlier edits change the total token count.
+    # eg "a cat.swap(smiling dog, s_start=0.5) eating a hotdog.swap(pizza)" - when the 'pizza' edit is active but the
+    # 'cat' edit is not, the 'pizza' feature vector will nevertheless be affected by the introduction of the extra
+    # token 'smiling' in the inactive 'cat' edit.
+    # todo: build multiple edited_embeddings, one for each edit, and pass just the edited fragments through to the CrossAttentionControl functions
+    edited_embeddings, edited_tokens = _get_embeddings_and_tokens_for_prompt(model,
+                                                                             edited_prompt,
+                                                                             log_tokens=log_tokens,
+                                                                             log_display_label="(.swap replacements)")
+    conditioning = original_embeddings
+    edited_conditioning = edited_embeddings
+    # print('>> got edit_opcodes', edit_opcodes, 'options', edit_options)
+    cac_args = cross_attention_control.Arguments(
+        edited_conditioning=edited_conditioning,
+        edit_opcodes=edit_opcodes,
+        edit_options=edit_options
+    )
+    return conditioning, cac_args
+
+
+def _get_conditioning_for_blend(model, blend: Blend, log_tokens: bool = False):
+    embeddings_to_blend = None
+    for i, flattened_prompt in enumerate(blend.prompts):
+        this_embedding, _ = _get_embeddings_and_tokens_for_prompt(model,
+                                                                  flattened_prompt,
+                                                                  log_tokens=log_tokens,
+                                                                  log_display_label=f"(blend part {i + 1}, weight={blend.weights[i]})")
+        embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat(
+            (embeddings_to_blend, this_embedding))
+    conditioning = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0),
+                                                                      blend.weights,
+                                                                      normalize=blend.normalize_weights)
+    return conditioning
+
+
+def _get_embeddings_and_tokens_for_prompt(model, flattened_prompt: FlattenedPrompt, log_tokens: bool = False,
+                                          log_display_label: str = None):
     if type(flattened_prompt) is not FlattenedPrompt:
         raise Exception(f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead")
     fragments = [x.text for x in flattened_prompt.children]
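For orientation, the opcode bookkeeping that the refactored _get_conditioning_for_cross_attention_control() performs pairs token ranges in the original and edited prompts. A purely illustrative sketch of the resulting data shape (token counts here are made up, not produced by the real tokenizer):

# Illustrative only: for a prompt like "a cat.swap(smiling dog) eating a sandwich",
# the list brackets a 'replace' entry with 'equal' entries for BOS and unchanged fragments.
edit_opcodes = [
    ('equal', 0, 1, 0, 1),    # beginning of sequence
    ('equal', 1, 2, 1, 2),    # "a" (1 token assumed)
    ('replace', 2, 3, 2, 4),  # "cat" -> "smiling dog" (1 -> 2 tokens assumed)
    ('equal', 3, 6, 4, 7),    # "eating a sandwich" (3 tokens assumed)
]
edit_options = [None, None, {'s_start': 0.5}, None]  # options travel with the matching opcode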
@@ -181,12 +228,14 @@ def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: Fl
 
     return embeddings, tokens
 
-def get_tokens_length(model, fragments: list[Fragment]):
+
+def _get_tokens_length(model, fragments: list[Fragment]):
     fragment_texts = [x.text for x in fragments]
     tokens = model.cond_stage_model.get_tokens(fragment_texts, include_start_and_end_markers=False)
     return sum([len(x) for x in tokens])
 
-def flatten_hybrid_conditioning(uncond, cond):
+
+def _flatten_hybrid_conditioning(uncond, cond):
     '''
     This handles the choice between a conditional conditioning
     that is a tensor (used by cross attention) vs one that has additional
@@ -205,4 +254,29 @@ def flatten_hybrid_conditioning(uncond, cond):
         cond_flattened[k] = torch.cat([uncond[k], cond[k]])
     return uncond, cond_flattened
 
 
+def log_tokenization(text, model, display_label=None):
+    """ shows how the prompt is tokenized
+    # usually tokens have '</w>' to indicate end-of-word,
+    # but for readability it has been replaced with ' '
+    """
+
+    tokens = model.cond_stage_model.tokenizer.tokenize(text)
+    tokenized = ""
+    discarded = ""
+    usedTokens = 0
+    totalTokens = len(tokens)
+    for i in range(0, totalTokens):
+        token = tokens[i].replace('</w>', ' ')
+        # alternate color
+        s = (usedTokens % 6) + 1
+        if i < model.cond_stage_model.max_length:
+            tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
+            usedTokens += 1
+        else: # over max token length
+            discarded = discarded + f"\x1b[0;3{s};40m{token}"
+    print(f"\n>> Tokens {display_label or ''} ({usedTokens}):\n{tokenized}\x1b[0m")
+    if discarded != "":
+        print(
+            f">> Tokens Discarded ({totalTokens - usedTokens}):\n{discarded}\x1b[0m"
+        )
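Taken together, the conditioning.py changes split prompt parsing from conditioning construction. A hedged sketch of how a caller might use the new public helpers (the loaded `model` object is assumed to be an ldm Stable Diffusion model with a cond_stage_model, as elsewhere in this diff):

# Sketch only; mirrors how the web server and CLI call into ldm.invoke.conditioning after this change.
from ldm.invoke.conditioning import get_uc_and_c_and_ec, get_prompt_structure, get_tokens_for_prompt
from ldm.invoke.prompt_parser import Blend

prompt_string = "a cat.swap(dog) on a bench [blurry]"

# structure only (no model needed): positive prompt plus bracketed negative prompt
parsed_prompt, parsed_negative = get_prompt_structure(prompt_string)

# token preview, as the web server now does for the "tokens" field (assumes `model` is already loaded)
tokens = None if type(parsed_prompt) is Blend else get_tokens_for_prompt(model, parsed_prompt)

# full conditioning for sampling
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt_string, model=model, log_tokens=True)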
@@ -19,6 +19,7 @@ from pytorch_lightning import seed_everything
 from tqdm import trange
 
 from ldm.invoke.devices import choose_autocast
+from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
 from ldm.models.diffusion.ddpm import DiffusionWrapper
 from ldm.util import rand_perlin_2d
 
@@ -62,9 +63,12 @@ class Generator:
     def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None,
                  image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
                  safety_checker:dict=None,
+                 attention_maps_callback = None,
                  **kwargs):
         scope = choose_autocast(self.precision)
         self.safety_checker = safety_checker
+        attention_maps_images = []
+        attention_maps_callback = lambda saver: attention_maps_images.append(saver.get_stacked_maps_image())
         make_image = self.get_make_image(
             prompt,
             sampler = sampler,
@@ -74,6 +78,7 @@ class Generator:
             step_callback = step_callback,
             threshold = threshold,
             perlin = perlin,
+            attention_maps_callback = attention_maps_callback,
             **kwargs
         )
         results = []
@@ -109,7 +114,7 @@ class Generator:
                 results.append([image, seed])
 
                 if image_callback is not None:
-                    image_callback(image, seed, first_seed=first_seed)
+                    image_callback(image, seed, first_seed=first_seed, attention_maps_image=attention_maps_images[-1])
 
                 seed = self.new_seed()
 
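The generator now threads an attention_maps_callback down to the pipeline and collects one stacked map image per result. As a minimal sketch of a caller-side callback following the same convention as the lambda installed above (the save path is an assumption):

# Sketch: a callback with the same shape as the lambda in Generator.generate().
# It receives an AttentionMapSaver and turns the accumulated maps into a PIL image.
def save_attention_maps(saver, path="attention_maps.png"):  # path is a hypothetical choice
    stacked = saver.get_stacked_maps_image()
    if stacked is not None:
        stacked.save(path, "PNG")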
@@ -15,6 +15,7 @@ class Txt2Img(Generator):
     @torch.no_grad()
     def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
                        conditioning,width,height,step_callback=None,threshold=0.0,perlin=0.0,
+                       attention_maps_callback=None,
                        **kwargs):
         """
         Returns a function returning an image derived from the prompt and the initial image
@@ -39,6 +40,7 @@ class Txt2Img(Generator):
                 extra_conditioning_info=extra_conditioning_info,
                 # TODO: eta = ddim_eta,
                 # TODO: threshold = threshold,
+                attention_maps_callback = attention_maps_callback,
             )
 
             return pipeline.numpy_to_pil(pipeline_output.images)[0]
@@ -3,7 +3,7 @@ from typing import Union, Optional
 import re
 import pyparsing as pp
 '''
 This module parses prompt strings and produces tree-like structures that can be used generate and control the conditioning tensors.
 weighted subprompts.
 
 Useful class exports:
@@ -69,6 +69,12 @@ class FlattenedPrompt():
         return len(self.children) == 0 or \
             (len(self.children) == 1 and len(self.children[0].text) == 0)
 
+    @property
+    def wants_cross_attention_control(self):
+        return any(
+            [issubclass(type(x), CrossAttentionControlledFragment) for x in self.children]
+        )
+
     def __repr__(self):
         return f"FlattenedPrompt:{self.children}"
     def __eq__(self, other):
@@ -240,6 +246,12 @@ class Blend():
         self.weights = weights
         self.normalize_weights = normalize_weights
 
+    @property
+    def wants_cross_attention_control(self):
+        # blends cannot cross-attention control
+        return False
+
+
     def __repr__(self):
         return f"Blend:{self.prompts} | weights {' ' if self.normalize_weights else '(non-normalized) '}{self.weights}"
     def __eq__(self, other):
@@ -277,8 +289,8 @@ class PromptParser():
 
         return self.flatten(root[0])
 
-    def parse_legacy_blend(self, text: str) -> Optional[Blend]:
-        weighted_subprompts = split_weighted_subprompts(text, skip_normalize=False)
+    def parse_legacy_blend(self, text: str, skip_normalize: bool) -> Optional[Blend]:
+        weighted_subprompts = split_weighted_subprompts(text, skip_normalize=skip_normalize)
         if len(weighted_subprompts) <= 1:
             return None
         strings = [x[0] for x in weighted_subprompts]
@@ -287,7 +299,7 @@ class PromptParser():
         parsed_conjunctions = [self.parse_conjunction(x) for x in strings]
         flattened_prompts = [x.prompts[0] for x in parsed_conjunctions]
 
-        return Blend(prompts=flattened_prompts, weights=weights, normalize_weights=True)
+        return Blend(prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize)
 
 
     def flatten(self, root: Conjunction, verbose = False) -> Conjunction:
@@ -641,27 +653,3 @@ def split_weighted_subprompts(text, skip_normalize=False)->list:
         return [(x[0], equal_weight) for x in parsed_prompts]
     return [(x[0], x[1] / weight_sum) for x in parsed_prompts]
 
-
-# shows how the prompt is tokenized
-# usually tokens have '</w>' to indicate end-of-word,
-# but for readability it has been replaced with ' '
-def log_tokenization(text, model, display_label=None):
-    tokens = model.cond_stage_model.tokenizer.tokenize(text)
-    tokenized = ""
-    discarded = ""
-    usedTokens = 0
-    totalTokens = len(tokens)
-    for i in range(0, totalTokens):
-        token = tokens[i].replace('</w>', ' ')
-        # alternate color
-        s = (usedTokens % 6) + 1
-        if i < model.cond_stage_model.max_length:
-            tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
-            usedTokens += 1
-        else: # over max token length
-            discarded = discarded + f"\x1b[0;3{s};40m{token}"
-    print(f"\n>> Tokens {display_label or ''} ({usedTokens}):\n{tokenized}\x1b[0m")
-    if discarded != "":
-        print(
-            f">> Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m"
-        )
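parse_legacy_blend() now receives the skip_normalize flag instead of hard-coding normalization. A small, hedged sketch of the legacy weighted-subprompt syntax it handles (the example prompt and weights are illustrative):

# Sketch only: legacy blend syntax "prompt:weight prompt:weight" parsed into a Blend.
from ldm.invoke.prompt_parser import PromptParser

pp = PromptParser()
blend = pp.parse_legacy_blend("a sunny meadow:1.0 a stormy beach:0.5", skip_normalize=False)
if blend is not None:
    print(blend.weights, blend.normalize_weights)  # weights are normalized unless skip_normalize is True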
@@ -53,7 +53,6 @@ COMMANDS = (
     '--codeformer_fidelity','-cf',
     '--upscale','-U',
     '-save_orig','--save_original',
-    '--skip_normalize','-x',
     '--log_tokenization','-t',
     '--hires_fix',
     '--inpaint_replace','-r',
@@ -117,19 +116,19 @@ class Completer(object):
         # extensions defined, so go directly into path completion mode
         if self.extensions is not None:
             self.matches = self._path_completions(text, state, self.extensions)
 
         # looking for an image file
         elif re.search(path_regexp,buffer):
             do_shortcut = re.search('^'+'|'.join(IMG_FILE_COMMANDS),buffer)
             self.matches = self._path_completions(text, state, IMG_EXTENSIONS,shortcut_ok=do_shortcut)
 
         # looking for a seed
         elif re.search('(-S\s*|--seed[=\s])\d*$',buffer):
             self.matches= self._seed_completions(text,state)
 
         elif re.search('<[\w-]*$',buffer):
             self.matches= self._concept_completions(text,state)
 
         # looking for a model
         elif re.match('^'+'|'.join(MODEL_COMMANDS),buffer):
             self.matches= self._model_completions(text, state)
@@ -227,7 +226,7 @@ class Completer(object):
         if h_len < 1:
             print('<empty history>')
             return
 
         for i in range(0,h_len):
             line = self.get_history_item(i+1)
             if match and match not in line:
@@ -367,7 +366,7 @@ class DummyCompleter(Completer):
     def __init__(self,options):
         super().__init__(options)
         self.history = list()
 
     def add_history(self,line):
         self.history.append(line)
 
@@ -7,11 +7,9 @@ import torch
 import diffusers
 from torch import nn
 
-
 # adapted from bloc97's CrossAttentionControl colab
 # https://github.com/bloc97/CrossAttentionControl
 
-
 class Arguments:
     def __init__(self, edited_conditioning: torch.Tensor, edit_opcodes: list[tuple], edit_options: dict):
         """
@@ -68,11 +66,11 @@ class Context:
         self.clear_requests(cleanup=True)
 
     def register_cross_attention_modules(self, model):
-        for name,module in get_attention_modules(model, CrossAttentionType.SELF):
+        for name,module in get_cross_attention_modules(model, CrossAttentionType.SELF):
             if name in self.self_cross_attention_module_identifiers:
                 assert False, f"name {name} cannot appear more than once"
             self.self_cross_attention_module_identifiers.append(name)
-        for name,module in get_attention_modules(model, CrossAttentionType.TOKENS):
+        for name,module in get_cross_attention_modules(model, CrossAttentionType.TOKENS):
             if name in self.tokens_cross_attention_module_identifiers:
                 assert False, f"name {name} cannot appear more than once"
             self.tokens_cross_attention_module_identifiers.append(name)
@@ -175,6 +173,135 @@ class Context:
             map_dict[offset] = slice.to('cpu')
 
 
+
+class InvokeAICrossAttentionMixin:
+    """
+    Enable InvokeAI-flavoured CrossAttention calculation, which does aggressive low-memory slicing and calls
+    through both to an attention_slice_wrangler and a slicing_strategy_getter for custom attention map wrangling
+    and dymamic slicing strategy selection.
+    """
+    def __init__(self):
+        self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
+        self.attention_slice_wrangler = None
+        self.slicing_strategy_getter = None
+        self.attention_slice_calculated_callback = None
+
+    def set_attention_slice_wrangler(self, wrangler: Optional[Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]]):
+        '''
+        Set custom attention calculator to be called when attention is calculated
+        :param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size),
+                         which returns either the suggested_attention_slice or an adjusted equivalent.
+            `module` is the current CrossAttention module for which the callback is being invoked.
+            `suggested_attention_slice` is the default-calculated attention slice
+            `dim` is -1 if the attenion map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
+                If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length.
+
+        Pass None to use the default attention calculation.
+        :return:
+        '''
+        self.attention_slice_wrangler = wrangler
+
+    def set_slicing_strategy_getter(self, getter: Optional[Callable[[nn.Module], tuple[int,int]]]):
+        self.slicing_strategy_getter = getter
+
+    def set_attention_slice_calculated_callback(self, callback: Optional[Callable[[torch.Tensor], None]]):
+        self.attention_slice_calculated_callback = callback
+
+    def einsum_lowest_level(self, query, key, value, dim, offset, slice_size):
+        # calculate attention scores
+        #attention_scores = torch.einsum('b i d, b j d -> b i j', q, k)
+        attention_scores = torch.baddbmm(
+            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+            query,
+            key.transpose(-1, -2),
+            beta=0,
+            alpha=self.scale,
+        )
+
+        # calculate attention slice by taking the best scores for each latent pixel
+        default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
+        attention_slice_wrangler = self.attention_slice_wrangler
+        if attention_slice_wrangler is not None:
+            attention_slice = attention_slice_wrangler(self, default_attention_slice, dim, offset, slice_size)
+        else:
+            attention_slice = default_attention_slice
+
+        if self.attention_slice_calculated_callback is not None:
+            self.attention_slice_calculated_callback(attention_slice, dim, offset, slice_size)
+
+        hidden_states = torch.bmm(attention_slice, value)
+        return hidden_states
+
+    def einsum_op_slice_dim0(self, q, k, v, slice_size):
+        r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+        for i in range(0, q.shape[0], slice_size):
+            end = i + slice_size
+            r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size)
+        return r
+
+    def einsum_op_slice_dim1(self, q, k, v, slice_size):
+        r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+        for i in range(0, q.shape[1], slice_size):
+            end = i + slice_size
+            r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size)
+        return r
+
+    def einsum_op_mps_v1(self, q, k, v):
+        if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
+            return self.einsum_lowest_level(q, k, v, None, None, None)
+        else:
+            slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
+            return self.einsum_op_slice_dim1(q, k, v, slice_size)
+
+    def einsum_op_mps_v2(self, q, k, v):
+        if self.mem_total_gb > 8 and q.shape[1] <= 4096:
+            return self.einsum_lowest_level(q, k, v, None, None, None)
+        else:
+            return self.einsum_op_slice_dim0(q, k, v, 1)
+
+    def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb):
+        size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
+        if size_mb <= max_tensor_mb:
+            return self.einsum_lowest_level(q, k, v, None, None, None)
+        div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
+        if div <= q.shape[0]:
+            return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div)
+        return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1))
+
+    def einsum_op_cuda(self, q, k, v):
+        # check if we already have a slicing strategy (this should only happen during cross-attention controlled generation)
+        slicing_strategy_getter = self.slicing_strategy_getter
+        if slicing_strategy_getter is not None:
+            (dim, slice_size) = slicing_strategy_getter(self)
+            if dim is not None:
+                # print("using saved slicing strategy with dim", dim, "slice size", slice_size)
+                if dim == 0:
+                    return self.einsum_op_slice_dim0(q, k, v, slice_size)
+                elif dim == 1:
+                    return self.einsum_op_slice_dim1(q, k, v, slice_size)
+
+        # fallback for when there is no saved strategy, or saved strategy does not slice
+        mem_free_total = self.cached_mem_free_total or get_mem_free_total(q.device)
+        # Divide factor of safety as there's copying and fragmentation
+        return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
+
+
+    def get_invokeai_attention_mem_efficient(self, q, k, v):
+        if q.device.type == 'cuda':
+            #print("in get_attention_mem_efficient with q shape", q.shape, ", k shape", k.shape, ", free memory is", get_mem_free_total(q.device))
+            return self.einsum_op_cuda(q, k, v)
+
+        if q.device.type == 'mps' or q.device.type == 'cpu':
+            if self.mem_total_gb >= 32:
+                return self.einsum_op_mps_v1(q, k, v)
+            return self.einsum_op_mps_v2(q, k, v)
+
+        # Smaller slices are faster due to L2/L3/SLC caches.
+        # Tested on i7 with 8MB L3 cache.
+        return self.einsum_op_tensor_mem(q, k, v, 32)
+
+
+
 def remove_cross_attention_control(model):
     remove_attention_function(model)
 
@@ -210,8 +337,7 @@ def setup_cross_attention_control(model, context: Context):
     inject_attention_function(model, context)
 
 
-def get_attention_modules(model, which: CrossAttentionType):
-    # cross_attention_class: type = ldm.modules.attention.CrossAttention
+def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
     cross_attention_class: type = InvokeAIDiffusersCrossAttention
     which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2"
     attention_module_tuples = [(name,module) for name, module in model.named_modules() if
@@ -220,7 +346,7 @@ def get_attention_modules(model, which: CrossAttentionType):
     expected_count = 16
     if cross_attention_modules_in_model_count != expected_count:
         # non-fatal error but .swap() won't work.
-        print(f"Error! CrossAttentionControl found an unexpected number of InvokeAICrossAttention modules in the model " +
+        print(f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model " +
              f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching failed " +
              f"or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, " +
             f"and/or update the {expected_count} above to an appropriate number, and/or find and inform someone who knows " +
@@ -266,7 +392,7 @@ def inject_attention_function(unet, context: Context):
 
         return attention_slice
 
-    cross_attention_modules = get_attention_modules(unet, CrossAttentionType.TOKENS) + get_attention_modules(unet, CrossAttentionType.SELF)
+    cross_attention_modules = get_cross_attention_modules(unet, CrossAttentionType.TOKENS) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
     for identifier, module in cross_attention_modules:
         module.identifier = identifier
         try:
@@ -282,7 +408,7 @@ def inject_attention_function(unet, context: Context):
 
 
 def remove_attention_function(unet):
-    cross_attention_modules = get_attention_modules(unet, CrossAttentionType.TOKENS) + get_attention_modules(unet, CrossAttentionType.SELF)
+    cross_attention_modules = get_cross_attention_modules(unet, CrossAttentionType.TOKENS) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
     for identifier, module in cross_attention_modules:
         try:
             # clear wrangler callback
@@ -316,7 +442,6 @@ def get_mem_free_total(device):
     return mem_free_total
 
 
-
 class InvokeAICrossAttentionMixin:
     """
     Enable InvokeAI-flavoured CrossAttention calculation, which does aggressive low-memory slicing and calls
@@ -452,4 +577,3 @@ class InvokeAIDiffusersCrossAttention(diffusers.models.attention.CrossAttention,
         hidden_states = self.reshape_batch_dim_to_heads(damian_result)
         return hidden_states
 
-
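The mixin above exposes set_attention_slice_wrangler() so cross-attention control (and the new map saving) can replace or inspect each attention slice. A hedged sketch of a no-op wrangler matching the documented signature (module, suggested_attention_slice, dim, offset, slice_size); the print is only illustrative:

# Sketch: a wrangler that just observes shapes and returns the suggested slice unchanged.
import torch

def passthrough_wrangler(module, suggested_attention_slice: torch.Tensor, dim, offset, slice_size):
    # dim is -1/None when the map was not sliced; 0 or 1 when slicing is in effect
    # print(getattr(module, "identifier", "?"), suggested_attention_slice.shape, dim, offset, slice_size)
    return suggested_attention_slice

# installed on each module returned by get_cross_attention_modules(...), e.g.:
# module.set_attention_slice_wrangler(passthrough_wrangler)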
95
ldm/models/diffusion/cross_attention_map_saving.py
Normal file
95
ldm/models/diffusion/cross_attention_map_saving.py
Normal file
@ -0,0 +1,95 @@
+import math
+
+import PIL
+import torch
+from torchvision.transforms.functional import resize as tv_resize, InterpolationMode
+
+from ldm.models.diffusion.cross_attention_control import get_cross_attention_modules, CrossAttentionType
+
+
+class AttentionMapSaver():
+
+    def __init__(self, token_ids: range, latents_shape: torch.Size):
+        self.token_ids = token_ids
+        self.latents_shape = latents_shape
+        #self.collated_maps = #torch.zeros([len(token_ids), latents_shape[0], latents_shape[1]])
+        self.collated_maps = {}
+
+    def clear_maps(self):
+        self.collated_maps = {}
+
+    def add_attention_maps(self, maps: torch.Tensor, key: str):
+        """
+        Accumulate the given attention maps and store by summing with existing maps at the passed-in key (if any).
+        :param maps: Attention maps to store. Expected shape [A, (H*W), N] where A is attention heads count, H and W are the map size (fixed per-key) and N is the number of tokens (typically 77).
+        :param key: Storage key. If a map already exists for this key it will be summed with the incoming data. In this case the maps sizes (H and W) should match.
+        :return: None
+        """
+        key_and_size = f'{key}_{maps.shape[1]}'
+
+        # extract desired tokens
+        maps = maps[:, :, self.token_ids]
+
+        # merge attention heads to a single map per token
+        maps = torch.sum(maps, 0)
+
+        # store
+        if key_and_size not in self.collated_maps:
+            self.collated_maps[key_and_size] = torch.zeros_like(maps, device='cpu')
+        self.collated_maps[key_and_size] += maps.cpu()
+
+    def write_maps_to_disk(self, path: str):
+        pil_image = self.get_stacked_maps_image()
+        pil_image.save(path, 'PNG')
+
+    def get_stacked_maps_image(self) -> PIL.Image:
+        """
+        Scale all collected attention maps to the same size, blend them together and return as an image.
+        :return: An image containing a vertical stack of blended attention maps, one for each requested token.
+        """
+        num_tokens = len(self.token_ids)
+        if num_tokens == 0:
+            return None
+
+        latents_height = self.latents_shape[0]
+        latents_width = self.latents_shape[1]
+
+        merged = None
+
+        for key, maps in self.collated_maps.items():
+
+            # maps has shape [(H*W), N] for N tokens
+            # but we want [N, H, W]
+            this_scale_factor = math.sqrt(maps.shape[0] / (latents_width * latents_height))
+            this_maps_height = int(float(latents_height) * this_scale_factor)
+            this_maps_width = int(float(latents_width) * this_scale_factor)
+            # and we need to do some dimension juggling
+            maps = torch.reshape(torch.swapdims(maps, 0, 1), [num_tokens, this_maps_height, this_maps_width])
+
+            # scale to output size if necessary
+            if this_scale_factor != 1:
+                maps = tv_resize(maps, [latents_height, latents_width], InterpolationMode.BICUBIC)
+
+            # normalize
+            maps_min = torch.min(maps)
+            maps_range = torch.max(maps) - maps_min
+            #print(f"map {key} size {[this_maps_width, this_maps_height]} range {[maps_min, maps_min + maps_range]}")
+            maps_normalized = (maps - maps_min) / maps_range
+            # expand to (-0.1, 1.1) and clamp
+            maps_normalized_expanded = maps_normalized * 1.1 - 0.05
+            maps_normalized_expanded_clamped = torch.clamp(maps_normalized_expanded, 0, 1)
+
+            # merge together, producing a vertical stack
+            maps_stacked = torch.reshape(maps_normalized_expanded_clamped, [num_tokens * latents_height, latents_width])
+
+            if merged is None:
+                merged = maps_stacked
+            else:
+                # screen blend
+                merged = 1 - (1 - maps_stacked)*(1 - merged)
+
+        if merged is None:
+            return None
+
+        merged_bytes = merged.mul(0xff).byte()
+        return PIL.Image.fromarray(merged_bytes.numpy(), mode='L')
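A hedged usage sketch of the new class: the head count, map size, prompt length and output path below are illustrative, but the shapes follow the docstrings above.

    import torch
    from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver

    # Track tokens 1..75 of a 77-token CLIP prompt over a 64x64 latent.
    saver = AttentionMapSaver(token_ids=range(1, 76), latents_shape=torch.Size([64, 64]))

    # Each cross-attention layer hands over maps shaped [heads, H*W, tokens].
    fake_maps = torch.rand(8, 64 * 64, 77)
    saver.add_attention_maps(fake_maps, key='mid')

    # After sampling, collapse everything into one stacked grayscale image.
    saver.write_maps_to_disk('attention_maps.png')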
@@ -4,6 +4,7 @@ import k_diffusion as K
 import torch
 from torch import nn
 
+from .cross_attention_map_saving import AttentionMapSaver
 from .sampler import Sampler
 from .shared_invokeai_diffusion import InvokeAIDiffuserComponent
 
@@ -36,6 +37,7 @@ class CFGDenoiser(nn.Module):
         self.invokeai_diffuser = InvokeAIDiffuserComponent(model,
                                                            model_forward_callback=lambda x, sigma, cond: self.inner_model(x, sigma, cond=cond))
 
+
     def prepare_to_sample(self, t_enc, **kwargs):
 
         extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
@@ -106,12 +108,12 @@ class KSampler(Sampler):
         else:
             print(f'>> Ksampler using karras noise schedule (steps < {self.karras_max})')
             self.sigmas = self.karras_sigmas
 
     # ALERT: We are completely overriding the sample() method in the base class, which
     # means that inpainting will not work. To get this to work we need to be able to
     # modify the inner loop of k_heun, k_lms, etc, as is done in an ugly way
     # in the lstein/k-diffusion branch.
 
     @torch.no_grad()
     def decode(
         self,
@@ -145,7 +147,7 @@ class KSampler(Sampler):
     @torch.no_grad()
     def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
         return x0
 
     # Most of these arguments are ignored and are only present for compatibility with
     # other samples
     @torch.no_grad()
@@ -158,6 +160,7 @@ class KSampler(Sampler):
         callback=None,
         normals_sequence=None,
         img_callback=None,
+        attention_maps_callback=None,
         quantize_x0=False,
         eta=0.0,
         mask=None,
@@ -171,7 +174,7 @@ class KSampler(Sampler):
         log_every_t=100,
         unconditional_guidance_scale=1.0,
         unconditional_conditioning=None,
-        extra_conditioning_info=None,
+        extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo=None,
         threshold = 0,
         perlin = 0,
         # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
@@ -204,6 +207,12 @@ class KSampler(Sampler):
 
         model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10))
         model_wrap_cfg.prepare_to_sample(S, extra_conditioning_info=extra_conditioning_info)
+
+        attention_map_token_ids = range(1, extra_conditioning_info.tokens_count_including_eos_bos - 1)
+        attention_maps_saver = None if attention_maps_callback is None else AttentionMapSaver(token_ids = attention_map_token_ids, latents_shape=x.shape[-2:])
+        if attention_maps_callback is not None:
+            model_wrap_cfg.invokeai_diffuser.setup_attention_map_saving(attention_maps_saver)
+
         extra_args = {
             'cond': conditioning,
             'uncond': unconditional_conditioning,
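The new attention_map_token_ids range drops the BOS and EOS markers so that only the real prompt tokens get a saved map; with an illustrative token count:

    tokens_count_including_eos_bos = 6   # e.g. <BOS> a cat on mat <EOS>
    attention_map_token_ids = range(1, tokens_count_including_eos_bos - 1)
    print(list(attention_map_token_ids))  # [1, 2, 3, 4] -- BOS (0) and EOS (5) skipped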
@@ -217,6 +226,8 @@ class KSampler(Sampler):
             ),
             None,
         )
+        if attention_maps_callback is not None:
+            attention_maps_callback(attention_maps_saver)
         return sampling_result
 
     # this code will support inpainting if and when ksampler API modified or
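Together with the previous hunk, this lets a caller receive the populated AttentionMapSaver once sampling ends. A hedged caller-side sketch follows; the surrounding variables and the abbreviated argument list are illustrative, and only the two keyword arguments introduced by this diff are shown under their real names.

    # `sampler`, `steps`, `c`, `uc`, `cfg_scale` and `extra_conditioning_info`
    # are assumed to exist in the calling code.
    saved = {}

    def on_attention_maps(saver):
        saved['maps'] = saver   # invoked once, after sampling finishes

    sampling_result = sampler.sample(
        S=steps,
        conditioning=c,
        unconditional_conditioning=uc,
        unconditional_guidance_scale=cfg_scale,
        # ...remaining sampler arguments (shape, batch size, etc.) elided...
        extra_conditioning_info=extra_conditioning_info,
        attention_maps_callback=on_attention_maps,
    )

    if 'maps' in saved:
        saved['maps'].write_maps_to_disk('attention_maps.png')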
@@ -248,7 +259,7 @@ class KSampler(Sampler):
         # terrible, confusing names here
         steps = self.ddim_num_steps
         t_enc = self.t_enc
 
         # sigmas is a full steps in length, but t_enc might
         # be less. We start in the middle of the sigma array
         # and work our way to the end after t_enc steps.
@@ -280,7 +291,7 @@ class KSampler(Sampler):
             return x_T + x
         else:
             return x
 
     def prepare_to_sample(self,t_enc,**kwargs):
         self.t_enc = t_enc
         self.model_wrap = None
@@ -5,8 +5,8 @@ from typing import Callable, Optional, Union
 import torch
 
 from ldm.models.diffusion.cross_attention_control import Arguments, \
-    remove_cross_attention_control, setup_cross_attention_control, Context
-from ldm.modules.attention import get_mem_free_total
+    remove_cross_attention_control, setup_cross_attention_control, Context, get_cross_attention_modules, CrossAttentionType
+from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
 
 
 class InvokeAIDiffuserComponent:
@@ -21,7 +21,8 @@ class InvokeAIDiffuserComponent:
 
 
     class ExtraConditioningInfo:
-        def __init__(self, cross_attention_control_args: Optional[Arguments]):
+        def __init__(self, tokens_count_including_eos_bos:int, cross_attention_control_args: Optional[Arguments]):
+            self.tokens_count_including_eos_bos = tokens_count_including_eos_bos
             self.cross_attention_control_args = cross_attention_control_args
 
         @property
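Callers constructing ExtraConditioningInfo now also supply the prompt's token count. A minimal hedged sketch (the count of 10 is a stand-in for the real tokenized prompt length):

    extra_conditioning_info = InvokeAIDiffuserComponent.ExtraConditioningInfo(
        tokens_count_including_eos_bos=10,
        cross_attention_control_args=None,
    )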
@@ -53,7 +54,25 @@ class InvokeAIDiffuserComponent:
         self.cross_attention_control_context = None
         remove_cross_attention_control(self.model)
 
+    def setup_attention_map_saving(self, saver: AttentionMapSaver):
+        def callback(slice, dim, offset, slice_size, key):
+            if dim is not None:
+                # sliced tokens attention map saving is not implemented
+                return
+            saver.add_attention_maps(slice, key)
+
+        tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
+        for identifier, module in tokens_cross_attention_modules:
+            key = ('down' if identifier.startswith('down') else
+                   'up' if identifier.startswith('up') else
+                   'mid')
+            module.set_attention_slice_calculated_callback(
+                lambda slice, dim, offset, slice_size, key=key: callback(slice, dim, offset, slice_size, key))
+
+    def remove_attention_map_saving(self):
+        tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
+        for _, module in tokens_cross_attention_modules:
+            module.set_attention_slice_calculated_callback(None)
+
     def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor,
                           unconditioning: Union[torch.Tensor,dict],
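Note the key=key default argument in the registration lambda above: Python closures capture variables rather than values, so without the default every registered callback would see whatever key held after the loop finished. A standalone illustration:

    callbacks_bad  = [lambda: key for key in ('down', 'mid', 'up')]
    callbacks_good = [lambda key=key: key for key in ('down', 'mid', 'up')]

    print([cb() for cb in callbacks_bad])   # ['up', 'up', 'up']
    print([cb() for cb in callbacks_good])  # ['down', 'mid', 'up']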
@@ -10,8 +10,6 @@ from einops import rearrange, repeat
 from ldm.models.diffusion.cross_attention_control import InvokeAICrossAttentionMixin
 from ldm.modules.diffusionmodules.util import checkpoint
 
-import psutil
-
 def exists(val):
     return val is not None
 
@@ -165,10 +163,10 @@ def get_mem_free_total(device):
     return mem_free_total
 
 class CrossAttention(nn.Module, InvokeAICrossAttentionMixin):
-
     def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
         print(f"Warning! ldm.modules.attention.CrossAttention is no longer being maintained. Please use InvokeAICrossAttention instead.")
         super().__init__()
+        InvokeAICrossAttentionMixin.__init__(self)
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)
 
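CrossAttention now also inherits InvokeAICrossAttentionMixin and calls the mixin's initializer by hand after super().__init__(). That explicit call is the usual workaround when a base class does not cooperatively chain to super(); a generic illustration in plain Python (not the repository's classes):

    class Base:
        def __init__(self):
            self.base_ready = True      # note: does not call super().__init__()

    class Mixin:
        def __init__(self):
            self.mixin_ready = True

    class Combined(Base, Mixin):
        def __init__(self):
            super().__init__()          # runs Base.__init__ only
            Mixin.__init__(self)        # so the mixin is initialized explicitly

    c = Combined()
    print(c.base_ready, c.mixin_ready)  # True True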
@@ -184,7 +182,6 @@ class CrossAttention(nn.Module, InvokeAICrossAttentionMixin):
             nn.Dropout(dropout)
         )
 
-
     def forward(self, x, context=None, mask=None):
         h = self.heads
 
@@ -196,7 +193,7 @@ class CrossAttention(nn.Module, InvokeAICrossAttentionMixin):
 
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
 
-        # prevent scale being applied twice
+        # don't apply scale twice
         cached_scale = self.scale
         self.scale = 1
         r = self.get_invokeai_attention_mem_efficient(q, k, v)