Merge branch 'development' into sync-dev-with-main

Lincoln Stein 2022-11-13 21:51:17 +00:00
commit f04d1bab21
10 changed files with 271 additions and 266 deletions

View File

@@ -65,14 +65,11 @@ requests. Be sure to use the provided templates. They will help aid diagnose iss
 ### Installation
 
-This fork is supported across multiple platforms. You can find individual installation instructions
-below.
-
-- #### [Linux](https://invoke-ai.github.io/InvokeAI/installation/INSTALL_LINUX/)
-- #### [Windows](https://invoke-ai.github.io/InvokeAI/installation/INSTALL_WINDOWS/)
-- #### [Macintosh](https://invoke-ai.github.io/InvokeAI/installation/INSTALL_MAC/)
+This fork is supported across Linux, Windows and Macintosh. Linux
+users can use either an Nvidia-based card (with CUDA support) or an
+AMD card (using the ROCm driver). For full installation and upgrade
+instructions, please see:
+[InvokeAI Installation Overview](https://invoke-ai.github.io/InvokeAI/installation/)
 
 ### Hardware Requirements

View File

@@ -30,7 +30,7 @@ model:
     target: ldm.modules.embedding_manager.EmbeddingManager
     params:
       placeholder_strings: ["*"]
-      initializer_words: ['face', 'man', 'photo', 'africanmale']
+      initializer_words: ['sculpture']
      per_image_tokens: false
       num_vectors_per_token: 1
       progressive_words: False
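
This hunk (and the two matching ones below) only swaps the textual-inversion initializer words; the learnable vector bound to the "*" placeholder is seeded from those words before training begins. A minimal sketch of that seeding idea, assuming the usual textual-inversion scheme (illustrative toy code, not InvokeAI's EmbeddingManager; the vocabulary and embedding table are invented stand-ins):

# Toy sketch: initializer_words seed the placeholder token's starting vector,
# so narrowing ['face', 'man', 'photo', 'africanmale'] to ['sculpture'] changes
# the training starting point, not the model architecture.
import torch

vocab = {'sculpture': 0, 'face': 1, 'man': 2, 'photo': 3}  # invented vocabulary
embedding_table = torch.randn(len(vocab), 768)             # invented CLIP-like embeddings

def init_placeholder_vector(initializer_words: list[str]) -> torch.Tensor:
    # average the initializer words' embeddings to seed the learnable vector
    ids = torch.tensor([vocab[w] for w in initializer_words])
    return embedding_table[ids].mean(dim=0).clone().requires_grad_(True)

star_vector = init_placeholder_vector(['sculpture'])  # seeds "*"
print(star_vector.shape)  # torch.Size([768])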

View File

@@ -30,7 +30,7 @@ model:
     target: ldm.modules.embedding_manager.EmbeddingManager
     params:
       placeholder_strings: ["*"]
-      initializer_words: ['face', 'man', 'photo', 'africanmale']
+      initializer_words: ['sculpture']
       per_image_tokens: false
       num_vectors_per_token: 1
       progressive_words: False

View File

@@ -22,7 +22,7 @@ model:
     target: ldm.modules.embedding_manager.EmbeddingManager
     params:
       placeholder_strings: ["*"]
-      initializer_words: ['face', 'man', 'photo', 'africanmale']
+      initializer_words: ['sculpture']
       per_image_tokens: false
       num_vectors_per_token: 6
       progressive_words: False

View File

@@ -80,12 +80,11 @@ Mac and Linux machines, and runs on GPU cards with as little as 4 GB or RAM.
 ## :octicons-package-dependencies-24: Installation
 
-This fork is supported across multiple platforms. You can find individual
-installation instructions below.
-
-- :fontawesome-brands-linux: [Linux](installation/INSTALL_LINUX.md)
-- :fontawesome-brands-windows: [Windows](installation/INSTALL_WINDOWS.md)
-- :fontawesome-brands-apple: [Macintosh](installation/INSTALL_MAC.md)
+This fork is supported across Linux, Windows and Macintosh. Linux
+users can use either an Nvidia-based card (with CUDA support) or an
+AMD card (using the ROCm driver). For full installation and upgrade
+instructions, please see:
+[InvokeAI Installation Overview](https://invoke-ai.github.io/InvokeAI/installation/)
 
 ## :fontawesome-solid-computer: Hardware Requirements
@@ -123,6 +122,14 @@ You wil need one of the following:
 ## :octicons-log-16: Latest Changes
 
+### v2.1.3 <small>(13 November 2022)</small>
+
+- A choice of installer scripts that automate installation and configuration. See [Installation](https://github.com/invoke-ai/InvokeAI/blob/2.1.3-rc6/docs/installation/INSTALL.md).
+- A streamlined manual installation process that works for both Conda and PIP-only installs. See [Manual Installation](https://github.com/invoke-ai/InvokeAI/blob/2.1.3-rc6/docs/installation/INSTALL_MANUAL.md).
+- The ability to save frequently-used startup options (model to load, steps, sampler, etc) in a `.invokeai` file. See [Client](https://github.com/invoke-ai/InvokeAI/blob/2.1.3-rc6/docs/features/CLI.md)
+- Support for AMD GPU cards (non-CUDA) on Linux machines.
+- Multiple bugs and edge cases squashed.
+
 ### v2.1.0 <small>(2 November 2022)</small>
 - [Inpainting](https://invoke-ai.github.io/InvokeAI/features/INPAINTING/)

View File

@@ -19,7 +19,7 @@ set INSTALL_ENV_DIR=%cd%\installer_files\env
 @rem https://mamba.readthedocs.io/en/latest/installation.html
 set MICROMAMBA_DOWNLOAD_URL=https://github.com/cmdr2/stable-diffusion-ui/releases/download/v1.1/micromamba.exe
 set RELEASE_URL=https://github.com/invoke-ai/InvokeAI
-set RELEASE_SOURCEBALL=/archive/refs/heads/v2.1.3.tar.gz
+set RELEASE_SOURCEBALL=/archive/refs/heads/main.tar.gz
 set PYTHON_BUILD_STANDALONE_URL=https://github.com/indygreg/python-build-standalone/releases/download
 set PYTHON_BUILD_STANDALONE=20221002/cpython-3.10.7+20221002-x86_64-pc-windows-msvc-shared-install_only.tar.gz

View File

@@ -74,7 +74,7 @@ fi
 INSTALL_ENV_DIR="$(pwd)/installer_files/env"
 MICROMAMBA_DOWNLOAD_URL="https://micro.mamba.pm/api/micromamba/${MAMBA_OS_NAME}-${MAMBA_ARCH}/latest"
 RELEASE_URL=https://github.com/invoke-ai/InvokeAI
-RELEASE_SOURCEBALL=/archive/refs/heads/v2.1.3.tar.gz
+RELEASE_SOURCEBALL=/archive/refs/heads/main.tar.gz
 PYTHON_BUILD_STANDALONE_URL=https://github.com/indygreg/python-build-standalone/releases/download
 if [ "$OS_NAME" == "darwin" ]; then
     PYTHON_BUILD_STANDALONE=20221002/cpython-3.10.7+20221002-${PY_ARCH}-apple-darwin-install_only.tar.gz
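
Both installer hunks (the .bat above and this .sh) make the same change. Presumably the scripts concatenate RELEASE_URL and RELEASE_SOURCEBALL into the tarball download address; under that assumption, a quick sketch of the effect of this change:

# Assumed behavior: the installers download RELEASE_URL + RELEASE_SOURCEBALL.
RELEASE_URL = "https://github.com/invoke-ai/InvokeAI"

old_sourceball = "/archive/refs/heads/v2.1.3.tar.gz"  # pinned snapshot (pre-change)
new_sourceball = "/archive/refs/heads/main.tar.gz"    # tracks the main branch

print(RELEASE_URL + old_sourceball)
# https://github.com/invoke-ai/InvokeAI/archive/refs/heads/v2.1.3.tar.gz
print(RELEASE_URL + new_sourceball)
# https://github.com/invoke-ai/InvokeAI/archive/refs/heads/main.tar.gz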

View File

@@ -14,7 +14,7 @@ import torch
 
 from .prompt_parser import PromptParser, Blend, FlattenedPrompt, \
     CrossAttentionControlledFragment, CrossAttentionControlSubstitute, Fragment, log_tokenization
-from ..models.diffusion.cross_attention_control import CrossAttentionControl
+from ..models.diffusion import cross_attention_control
 from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
 from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder
@@ -49,7 +49,7 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
     parsed_negative_prompt: FlattenedPrompt = pp.parse_conjunction(unconditioned_words).prompts[0]
 
     conditioning = None
-    cac_args:CrossAttentionControl.Arguments = None
+    cac_args:cross_attention_control.Arguments = None
 
     if type(parsed_prompt) is Blend:
         blend: Blend = parsed_prompt
@@ -120,7 +120,7 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
             conditioning = original_embeddings
             edited_conditioning = edited_embeddings
         #print('>> got edit_opcodes', edit_opcodes, 'options', edit_options)
-        cac_args = CrossAttentionControl.Arguments(
+        cac_args = cross_attention_control.Arguments(
             edited_conditioning = edited_conditioning,
             edit_opcodes = edit_opcodes,
             edit_options = edit_options
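
The Arguments docstring (in the next file) describes edit_opcodes as difflib.SequenceMatcher-like, so standard-library difflib output illustrates the data shape the renamed module consumes; the token lists here are invented for the example:

import difflib

# Toy token sequences for an original and an edited prompt (invented data).
original_tokens = ['a', 'cat', 'sitting', 'on', 'a', 'car']
edited_tokens = ['a', 'smiling', 'dog', 'sitting', 'on', 'a', 'car']

# Each opcode is (tag, a0, a1, b0, b1); only the 'equal' spans are required
# by the cross-attention control code path.
opcodes = difflib.SequenceMatcher(None, original_tokens, edited_tokens).get_opcodes()
for tag, a0, a1, b0, b1 in opcodes:
    print(tag, original_tokens[a0:a1], '->', edited_tokens[b0:b1])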

View File

@@ -7,256 +7,256 @@ import torch
 # https://github.com/bloc97/CrossAttentionControl
 
-class CrossAttentionControl:
-
-    class Arguments:
-        def __init__(self, edited_conditioning: torch.Tensor, edit_opcodes: list[tuple], edit_options: dict):
-            """
-            :param edited_conditioning: if doing cross-attention control, the edited conditioning [1 x 77 x 768]
-            :param edit_opcodes: if doing cross-attention control, a list of difflib.SequenceMatcher-like opcodes describing how to map original conditioning tokens to edited conditioning tokens (only the 'equal' opcode is required)
-            :param edit_options: if doing cross-attention control, per-edit options. there should be 1 item in edit_options for each item in edit_opcodes.
-            """
-            # todo: rewrite this to take embedding fragments rather than a single edited_conditioning vector
-            self.edited_conditioning = edited_conditioning
-            self.edit_opcodes = edit_opcodes
-
-            if edited_conditioning is not None:
-                assert len(edit_opcodes) == len(edit_options), \
-                    "there must be 1 edit_options dict for each edit_opcodes tuple"
-                non_none_edit_options = [x for x in edit_options if x is not None]
-                assert len(non_none_edit_options)>0, "missing edit_options"
-                if len(non_none_edit_options)>1:
-                    print('warning: cross-attention control options are not working properly for >1 edit')
-                self.edit_options = non_none_edit_options[0]
-
-    class Context:
-
-        class Action(enum.Enum):
-            NONE = 0
-            SAVE = 1,
-            APPLY = 2
-
-        def __init__(self, arguments: 'CrossAttentionControl.Arguments', step_count: int):
-            """
-            :param arguments: Arguments for the cross-attention control process
-            :param step_count: The absolute total number of steps of diffusion (for img2img this is likely larger than the number of steps that will actually run)
-            """
-            self.arguments = arguments
-            self.step_count = step_count
-
-            self.self_cross_attention_module_identifiers = []
-            self.tokens_cross_attention_module_identifiers = []
-
-            self.saved_cross_attention_maps = {}
-
-            self.clear_requests(cleanup=True)
-
-        def register_cross_attention_modules(self, model):
-            for name,module in CrossAttentionControl.get_attention_modules(model,
-                                                                           CrossAttentionControl.CrossAttentionType.SELF):
-                self.self_cross_attention_module_identifiers.append(name)
-            for name,module in CrossAttentionControl.get_attention_modules(model,
-                                                                           CrossAttentionControl.CrossAttentionType.TOKENS):
-                self.tokens_cross_attention_module_identifiers.append(name)
-
-        def request_save_attention_maps(self, cross_attention_type: 'CrossAttentionControl.CrossAttentionType'):
-            if cross_attention_type == CrossAttentionControl.CrossAttentionType.SELF:
-                self.self_cross_attention_action = CrossAttentionControl.Context.Action.SAVE
-            else:
-                self.tokens_cross_attention_action = CrossAttentionControl.Context.Action.SAVE
-
-        def request_apply_saved_attention_maps(self, cross_attention_type: 'CrossAttentionControl.CrossAttentionType'):
-            if cross_attention_type == CrossAttentionControl.CrossAttentionType.SELF:
-                self.self_cross_attention_action = CrossAttentionControl.Context.Action.APPLY
-            else:
-                self.tokens_cross_attention_action = CrossAttentionControl.Context.Action.APPLY
-
-        def is_tokens_cross_attention(self, module_identifier) -> bool:
-            return module_identifier in self.tokens_cross_attention_module_identifiers
-
-        def get_should_save_maps(self, module_identifier: str) -> bool:
-            if module_identifier in self.self_cross_attention_module_identifiers:
-                return self.self_cross_attention_action == CrossAttentionControl.Context.Action.SAVE
-            elif module_identifier in self.tokens_cross_attention_module_identifiers:
-                return self.tokens_cross_attention_action == CrossAttentionControl.Context.Action.SAVE
-            return False
-
-        def get_should_apply_saved_maps(self, module_identifier: str) -> bool:
-            if module_identifier in self.self_cross_attention_module_identifiers:
-                return self.self_cross_attention_action == CrossAttentionControl.Context.Action.APPLY
-            elif module_identifier in self.tokens_cross_attention_module_identifiers:
-                return self.tokens_cross_attention_action == CrossAttentionControl.Context.Action.APPLY
-            return False
-
-        def get_active_cross_attention_control_types_for_step(self, percent_through:float=None)\
-                -> list['CrossAttentionControl.CrossAttentionType']:
-            """
-            Should cross-attention control be applied on the given step?
-            :param percent_through: How far through the step sequence are we (0.0=pure noise, 1.0=completely denoised image). Expected range 0.0..<1.0.
-            :return: A list of attention types that cross-attention control should be performed for on the given step. May be [].
-            """
-            if percent_through is None:
-                return [CrossAttentionControl.CrossAttentionType.SELF, CrossAttentionControl.CrossAttentionType.TOKENS]
-
-            opts = self.arguments.edit_options
-            to_control = []
-            if opts['s_start'] <= percent_through and percent_through < opts['s_end']:
-                to_control.append(CrossAttentionControl.CrossAttentionType.SELF)
-            if opts['t_start'] <= percent_through and percent_through < opts['t_end']:
-                to_control.append(CrossAttentionControl.CrossAttentionType.TOKENS)
-            return to_control
-
-        def save_slice(self, identifier: str, slice: torch.Tensor, dim: Optional[int], offset: int,
-                       slice_size: Optional[int]):
-            if identifier not in self.saved_cross_attention_maps:
-                self.saved_cross_attention_maps[identifier] = {
-                    'dim': dim,
-                    'slice_size': slice_size,
-                    'slices': {offset or 0: slice}
-                }
-            else:
-                self.saved_cross_attention_maps[identifier]['slices'][offset or 0] = slice
-
-        def get_slice(self, identifier: str, requested_dim: Optional[int], requested_offset: int, slice_size: int):
-            saved_attention_dict = self.saved_cross_attention_maps[identifier]
-            if requested_dim is None:
-                if saved_attention_dict['dim'] is not None:
-                    raise RuntimeError(f"dim mismatch: expected dim=None, have {saved_attention_dict['dim']}")
-                return saved_attention_dict['slices'][0]
-
-            if saved_attention_dict['dim'] == requested_dim:
-                if slice_size != saved_attention_dict['slice_size']:
-                    raise RuntimeError(
-                        f"slice_size mismatch: expected slice_size={slice_size}, have {saved_attention_dict['slice_size']}")
-                return saved_attention_dict['slices'][requested_offset]
-
-            if saved_attention_dict['dim'] == None:
-                whole_saved_attention = saved_attention_dict['slices'][0]
-                if requested_dim == 0:
-                    return whole_saved_attention[requested_offset:requested_offset + slice_size]
-                elif requested_dim == 1:
-                    return whole_saved_attention[:, requested_offset:requested_offset + slice_size]
-
-            raise RuntimeError(f"Cannot convert dim {saved_attention_dict['dim']} to requested dim {requested_dim}")
-
-        def get_slicing_strategy(self, identifier: str) -> Optional[tuple[int, int]]:
-            saved_attention = self.saved_cross_attention_maps.get(identifier, None)
-            if saved_attention is None:
-                return None, None
-            return saved_attention['dim'], saved_attention['slice_size']
-
-        def clear_requests(self, cleanup=True):
-            self.tokens_cross_attention_action = CrossAttentionControl.Context.Action.NONE
-            self.self_cross_attention_action = CrossAttentionControl.Context.Action.NONE
-            if cleanup:
-                self.saved_cross_attention_maps = {}
-
-        def offload_saved_attention_slices_to_cpu(self):
-            for key, map_dict in self.saved_cross_attention_maps.items():
-                for offset, slice in map_dict['slices'].items():
-                    map_dict[offset] = slice.to('cpu')
-
-    class CrossAttentionType(enum.Enum):
-        SELF = 1
-        TOKENS = 2
-
-    @classmethod
-    def remove_cross_attention_control(cls, model):
-        cls.remove_attention_function(model)
-
-    @classmethod
-    def setup_cross_attention_control(cls, model, context: Context):
-        """
-        Inject attention parameters and functions into the passed in model to enable cross attention editing.
-
-        :param model: The unet model to inject into.
-        :param cross_attention_control_args: Arugments passeed to the CrossAttentionControl implementations
-        :return: None
-        """
-
-        # adapted from init_attention_edit
-        device = context.arguments.edited_conditioning.device
-
-        # urgh. should this be hardcoded?
-        max_length = 77
-        # mask=1 means use base prompt attention, mask=0 means use edited prompt attention
-        mask = torch.zeros(max_length)
-        indices_target = torch.arange(max_length, dtype=torch.long)
-        indices = torch.zeros(max_length, dtype=torch.long)
-        for name, a0, a1, b0, b1 in context.arguments.edit_opcodes:
-            if b0 < max_length:
-                if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0):
-                    # these tokens have not been edited
-                    indices[b0:b1] = indices_target[a0:a1]
-                    mask[b0:b1] = 1
-
-        context.register_cross_attention_modules(model)
-        context.cross_attention_mask = mask.to(device)
-        context.cross_attention_index_map = indices.to(device)
-
-        cls.inject_attention_function(model, context)
-
-    @classmethod
-    def get_attention_modules(cls, model, which: CrossAttentionType):
-        which_attn = "attn1" if which is cls.CrossAttentionType.SELF else "attn2"
-        return [(name,module) for name, module in model.named_modules() if
-                      type(module).__name__ == "CrossAttention" and which_attn in name]
-
-    @classmethod
-    def inject_attention_function(cls, unet, context: 'CrossAttentionControl.Context'):
-        # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
-
-        def attention_slice_wrangler(module, suggested_attention_slice:torch.Tensor, dim, offset, slice_size):
-
-            #memory_usage = suggested_attention_slice.element_size() * suggested_attention_slice.nelement()
-
-            attention_slice = suggested_attention_slice
-
-            if context.get_should_save_maps(module.identifier):
-                #print(module.identifier, "saving suggested_attention_slice of shape",
-                #      suggested_attention_slice.shape, "dim", dim, "offset", offset)
-                slice_to_save = attention_slice.to('cpu') if dim is not None else attention_slice
-                context.save_slice(module.identifier, slice_to_save, dim=dim, offset=offset, slice_size=slice_size)
-            elif context.get_should_apply_saved_maps(module.identifier):
-                #print(module.identifier, "applying saved attention slice for dim", dim, "offset", offset)
-                saved_attention_slice = context.get_slice(module.identifier, dim, offset, slice_size)
-
-                # slice may have been offloaded to CPU
-                saved_attention_slice = saved_attention_slice.to(suggested_attention_slice.device)
-
-                if context.is_tokens_cross_attention(module.identifier):
-                    index_map = context.cross_attention_index_map
-                    remapped_saved_attention_slice = torch.index_select(saved_attention_slice, -1, index_map)
-                    this_attention_slice = suggested_attention_slice
-
-                    mask = context.cross_attention_mask
-                    saved_mask = mask
-                    this_mask = 1 - mask
-                    attention_slice = remapped_saved_attention_slice * saved_mask + \
-                                      this_attention_slice * this_mask
-                else:
-                    # just use everything
-                    attention_slice = saved_attention_slice
-
-            return attention_slice
-
-        for name, module in unet.named_modules():
-            module_name = type(module).__name__
-            if module_name == "CrossAttention":
-                module.identifier = name
-                module.set_attention_slice_wrangler(attention_slice_wrangler)
-                module.set_slicing_strategy_getter(lambda module, module_identifier=name: \
-                                                       context.get_slicing_strategy(module_identifier))
-
-    @classmethod
-    def remove_attention_function(cls, unet):
-        # clear wrangler callback
-        for name, module in unet.named_modules():
-            module_name = type(module).__name__
-            if module_name == "CrossAttention":
-                module.set_attention_slice_wrangler(None)
-                module.set_slicing_strategy_getter(None)
+class Arguments:
+    def __init__(self, edited_conditioning: torch.Tensor, edit_opcodes: list[tuple], edit_options: dict):
+        """
+        :param edited_conditioning: if doing cross-attention control, the edited conditioning [1 x 77 x 768]
+        :param edit_opcodes: if doing cross-attention control, a list of difflib.SequenceMatcher-like opcodes describing how to map original conditioning tokens to edited conditioning tokens (only the 'equal' opcode is required)
+        :param edit_options: if doing cross-attention control, per-edit options. there should be 1 item in edit_options for each item in edit_opcodes.
+        """
+        # todo: rewrite this to take embedding fragments rather than a single edited_conditioning vector
+        self.edited_conditioning = edited_conditioning
+        self.edit_opcodes = edit_opcodes
+
+        if edited_conditioning is not None:
+            assert len(edit_opcodes) == len(edit_options), \
+                    "there must be 1 edit_options dict for each edit_opcodes tuple"
+            non_none_edit_options = [x for x in edit_options if x is not None]
+            assert len(non_none_edit_options)>0, "missing edit_options"
+            if len(non_none_edit_options)>1:
+                print('warning: cross-attention control options are not working properly for >1 edit')
+            self.edit_options = non_none_edit_options[0]
+
+
+class CrossAttentionType(enum.Enum):
+    SELF = 1
+    TOKENS = 2
+
+
+class Context:
+
+    cross_attention_mask: Optional[torch.Tensor]
+    cross_attention_index_map: Optional[torch.Tensor]
+
+    class Action(enum.Enum):
+        NONE = 0
+        SAVE = 1,
+        APPLY = 2
+
+    def __init__(self, arguments: Arguments, step_count: int):
+        """
+        :param arguments: Arguments for the cross-attention control process
+        :param step_count: The absolute total number of steps of diffusion (for img2img this is likely larger than the number of steps that will actually run)
+        """
+        self.cross_attention_mask = None
+        self.cross_attention_index_map = None
+        self.self_cross_attention_action = Context.Action.NONE
+        self.tokens_cross_attention_action = Context.Action.NONE
+        self.arguments = arguments
+        self.step_count = step_count
+
+        self.self_cross_attention_module_identifiers = []
+        self.tokens_cross_attention_module_identifiers = []
+
+        self.saved_cross_attention_maps = {}
+
+        self.clear_requests(cleanup=True)
+
+    def register_cross_attention_modules(self, model):
+        for name,module in get_attention_modules(model, CrossAttentionType.SELF):
+            self.self_cross_attention_module_identifiers.append(name)
+        for name,module in get_attention_modules(model, CrossAttentionType.TOKENS):
+            self.tokens_cross_attention_module_identifiers.append(name)
+
+    def request_save_attention_maps(self, cross_attention_type: CrossAttentionType):
+        if cross_attention_type == CrossAttentionType.SELF:
+            self.self_cross_attention_action = Context.Action.SAVE
+        else:
+            self.tokens_cross_attention_action = Context.Action.SAVE
+
+    def request_apply_saved_attention_maps(self, cross_attention_type: CrossAttentionType):
+        if cross_attention_type == CrossAttentionType.SELF:
+            self.self_cross_attention_action = Context.Action.APPLY
+        else:
+            self.tokens_cross_attention_action = Context.Action.APPLY
+
+    def is_tokens_cross_attention(self, module_identifier) -> bool:
+        return module_identifier in self.tokens_cross_attention_module_identifiers
+
+    def get_should_save_maps(self, module_identifier: str) -> bool:
+        if module_identifier in self.self_cross_attention_module_identifiers:
+            return self.self_cross_attention_action == Context.Action.SAVE
+        elif module_identifier in self.tokens_cross_attention_module_identifiers:
+            return self.tokens_cross_attention_action == Context.Action.SAVE
+        return False
+
+    def get_should_apply_saved_maps(self, module_identifier: str) -> bool:
+        if module_identifier in self.self_cross_attention_module_identifiers:
+            return self.self_cross_attention_action == Context.Action.APPLY
+        elif module_identifier in self.tokens_cross_attention_module_identifiers:
+            return self.tokens_cross_attention_action == Context.Action.APPLY
+        return False
+
+    def get_active_cross_attention_control_types_for_step(self, percent_through:float=None)\
+            -> list[CrossAttentionType]:
+        """
+        Should cross-attention control be applied on the given step?
+        :param percent_through: How far through the step sequence are we (0.0=pure noise, 1.0=completely denoised image). Expected range 0.0..<1.0.
+        :return: A list of attention types that cross-attention control should be performed for on the given step. May be [].
+        """
+        if percent_through is None:
+            return [CrossAttentionType.SELF, CrossAttentionType.TOKENS]
+
+        opts = self.arguments.edit_options
+        to_control = []
+        if opts['s_start'] <= percent_through < opts['s_end']:
+            to_control.append(CrossAttentionType.SELF)
+        if opts['t_start'] <= percent_through < opts['t_end']:
+            to_control.append(CrossAttentionType.TOKENS)
+        return to_control
+
+    def save_slice(self, identifier: str, slice: torch.Tensor, dim: Optional[int], offset: int,
+                   slice_size: Optional[int]):
+        if identifier not in self.saved_cross_attention_maps:
+            self.saved_cross_attention_maps[identifier] = {
+                'dim': dim,
+                'slice_size': slice_size,
+                'slices': {offset or 0: slice}
+            }
+        else:
+            self.saved_cross_attention_maps[identifier]['slices'][offset or 0] = slice
+
+    def get_slice(self, identifier: str, requested_dim: Optional[int], requested_offset: int, slice_size: int):
+        saved_attention_dict = self.saved_cross_attention_maps[identifier]
+        if requested_dim is None:
+            if saved_attention_dict['dim'] is not None:
+                raise RuntimeError(f"dim mismatch: expected dim=None, have {saved_attention_dict['dim']}")
+            return saved_attention_dict['slices'][0]
+
+        if saved_attention_dict['dim'] == requested_dim:
+            if slice_size != saved_attention_dict['slice_size']:
+                raise RuntimeError(
+                    f"slice_size mismatch: expected slice_size={slice_size}, have {saved_attention_dict['slice_size']}")
+            return saved_attention_dict['slices'][requested_offset]
+
+        if saved_attention_dict['dim'] is None:
+            whole_saved_attention = saved_attention_dict['slices'][0]
+            if requested_dim == 0:
+                return whole_saved_attention[requested_offset:requested_offset + slice_size]
+            elif requested_dim == 1:
+                return whole_saved_attention[:, requested_offset:requested_offset + slice_size]
+
+        raise RuntimeError(f"Cannot convert dim {saved_attention_dict['dim']} to requested dim {requested_dim}")
+
+    def get_slicing_strategy(self, identifier: str) -> tuple[Optional[int], Optional[int]]:
+        saved_attention = self.saved_cross_attention_maps.get(identifier, None)
+        if saved_attention is None:
+            return None, None
+        return saved_attention['dim'], saved_attention['slice_size']
+
+    def clear_requests(self, cleanup=True):
+        self.tokens_cross_attention_action = Context.Action.NONE
+        self.self_cross_attention_action = Context.Action.NONE
+        if cleanup:
+            self.saved_cross_attention_maps = {}
+
+    def offload_saved_attention_slices_to_cpu(self):
+        for key, map_dict in self.saved_cross_attention_maps.items():
+            for offset, slice in map_dict['slices'].items():
+                map_dict[offset] = slice.to('cpu')
+
+
+def remove_cross_attention_control(model):
+    remove_attention_function(model)
+
+
+def setup_cross_attention_control(model, context: Context):
+    """
+    Inject attention parameters and functions into the passed in model to enable cross attention editing.
+
+    :param model: The unet model to inject into.
+    :param cross_attention_control_args: Arugments passeed to the CrossAttentionControl implementations
+    :return: None
+    """
+
+    # adapted from init_attention_edit
+    device = context.arguments.edited_conditioning.device
+
+    # urgh. should this be hardcoded?
+    max_length = 77
+    # mask=1 means use base prompt attention, mask=0 means use edited prompt attention
+    mask = torch.zeros(max_length)
+    indices_target = torch.arange(max_length, dtype=torch.long)
+    indices = torch.zeros(max_length, dtype=torch.long)
+    for name, a0, a1, b0, b1 in context.arguments.edit_opcodes:
+        if b0 < max_length:
+            if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0):
+                # these tokens have not been edited
+                indices[b0:b1] = indices_target[a0:a1]
+                mask[b0:b1] = 1
+
+    context.register_cross_attention_modules(model)
+    context.cross_attention_mask = mask.to(device)
+    context.cross_attention_index_map = indices.to(device)
+
+    inject_attention_function(model, context)
+
+
+def get_attention_modules(model, which: CrossAttentionType):
+    which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2"
+    return [(name,module) for name, module in model.named_modules() if
+                  type(module).__name__ == "CrossAttention" and which_attn in name]
+
+
+def inject_attention_function(unet, context: Context):
+    # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
+
+    def attention_slice_wrangler(module, suggested_attention_slice:torch.Tensor, dim, offset, slice_size):
+
+        #memory_usage = suggested_attention_slice.element_size() * suggested_attention_slice.nelement()
+
+        attention_slice = suggested_attention_slice
+
+        if context.get_should_save_maps(module.identifier):
+            #print(module.identifier, "saving suggested_attention_slice of shape",
+            #      suggested_attention_slice.shape, "dim", dim, "offset", offset)
+            slice_to_save = attention_slice.to('cpu') if dim is not None else attention_slice
+            context.save_slice(module.identifier, slice_to_save, dim=dim, offset=offset, slice_size=slice_size)
+        elif context.get_should_apply_saved_maps(module.identifier):
+            #print(module.identifier, "applying saved attention slice for dim", dim, "offset", offset)
+            saved_attention_slice = context.get_slice(module.identifier, dim, offset, slice_size)
+
+            # slice may have been offloaded to CPU
+            saved_attention_slice = saved_attention_slice.to(suggested_attention_slice.device)
+
+            if context.is_tokens_cross_attention(module.identifier):
+                index_map = context.cross_attention_index_map
+                remapped_saved_attention_slice = torch.index_select(saved_attention_slice, -1, index_map)
+                this_attention_slice = suggested_attention_slice
+
+                mask = context.cross_attention_mask
+                saved_mask = mask
+                this_mask = 1 - mask
+                attention_slice = remapped_saved_attention_slice * saved_mask + \
+                                  this_attention_slice * this_mask
+            else:
+                # just use everything
+                attention_slice = saved_attention_slice
+
+        return attention_slice
+
+    for name, module in unet.named_modules():
+        module_name = type(module).__name__
+        if module_name == "CrossAttention":
+            module.identifier = name
+            module.set_attention_slice_wrangler(attention_slice_wrangler)
+            module.set_slicing_strategy_getter(lambda module, module_identifier=name: \
+                                                   context.get_slicing_strategy(module_identifier))
+
+
+def remove_attention_function(unet):
+    # clear wrangler callback
+    for name, module in unet.named_modules():
+        module_name = type(module).__name__
+        if module_name == "CrossAttention":
+            module.set_attention_slice_wrangler(None)
+            module.set_slicing_strategy_getter(None)
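
setup_cross_attention_control turns the edit opcodes into a token mask and an index map. A self-contained sketch of just that arithmetic, lifted from the function above (the opcodes here are toy data; no unet is required):

import torch

max_length = 77  # hardcoded CLIP context length, as in the file above
edit_opcodes = [('equal', 0, 2, 0, 2), ('replace', 2, 3, 2, 4), ('equal', 3, 6, 4, 7)]  # toy data

# mask=1 means use base prompt attention, mask=0 means use edited prompt attention
mask = torch.zeros(max_length)
indices_target = torch.arange(max_length, dtype=torch.long)
indices = torch.zeros(max_length, dtype=torch.long)
for name, a0, a1, b0, b1 in edit_opcodes:
    if b0 < max_length and name == "equal":
        # unedited tokens keep their original attention, remapped to their new positions
        indices[b0:b1] = indices_target[a0:a1]
        mask[b0:b1] = 1

print(mask[:8])     # tensor([1., 1., 0., 0., 1., 1., 1., 0.])
print(indices[:8])  # tensor([0, 1, 0, 0, 3, 4, 5, 0])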

View File

@@ -4,7 +4,8 @@ from typing import Callable, Optional, Union
 
 import torch
 
-from ldm.models.diffusion.cross_attention_control import CrossAttentionControl
+from ldm.models.diffusion.cross_attention_control import Arguments, \
+    remove_cross_attention_control, setup_cross_attention_control, Context
 from ldm.modules.attention import get_mem_free_total
@@ -20,7 +21,7 @@ class InvokeAIDiffuserComponent:
     class ExtraConditioningInfo:
-        def __init__(self, cross_attention_control_args: Optional[CrossAttentionControl.Arguments]):
+        def __init__(self, cross_attention_control_args: Optional[Arguments]):
             self.cross_attention_control_args = cross_attention_control_args
 
         @property
@@ -40,16 +41,16 @@ class InvokeAIDiffuserComponent:
     def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo, step_count: int):
         self.conditioning = conditioning
-        self.cross_attention_control_context = CrossAttentionControl.Context(
+        self.cross_attention_control_context = Context(
             arguments=self.conditioning.cross_attention_control_args,
             step_count=step_count
         )
-        CrossAttentionControl.setup_cross_attention_control(self.model, self.cross_attention_control_context)
+        setup_cross_attention_control(self.model, self.cross_attention_control_context)
 
     def remove_cross_attention_control(self):
         self.conditioning = None
         self.cross_attention_control_context = None
-        CrossAttentionControl.remove_cross_attention_control(self.model)
+        remove_cross_attention_control(self.model)
@@ -71,7 +72,7 @@ class InvokeAIDiffuserComponent:
         cross_attention_control_types_to_do = []
-        context: CrossAttentionControl.Context = self.cross_attention_control_context
+        context: Context = self.cross_attention_control_context
         if self.cross_attention_control_context is not None:
             percent_through = self.estimate_percent_through(step_index, sigma)
             cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(percent_through)
@@ -133,7 +134,7 @@ class InvokeAIDiffuserComponent:
         # representing batched uncond + cond, but then when it comes to applying the saved attention, the
         # wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.)
         # todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well.
-        context:CrossAttentionControl.Context = self.cross_attention_control_context
+        context:Context = self.cross_attention_control_context
         try:
             unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
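
For the percent_through value passed above, get_active_cross_attention_control_types_for_step compares it against the s_*/t_* edit options to decide which attention types to control on a given step. A toy rendering of that check (the option values are invented):

# Toy s_*/t_* fractions of the denoising schedule (invented values).
opts = {'s_start': 0.0, 's_end': 0.3, 't_start': 0.1, 't_end': 1.0}

def active_types(percent_through: float) -> list[str]:
    to_control = []
    if opts['s_start'] <= percent_through < opts['s_end']:
        to_control.append('SELF')    # control self-attention (attn1) on this step
    if opts['t_start'] <= percent_through < opts['t_end']:
        to_control.append('TOKENS')  # control token cross-attention (attn2)
    return to_control

for p in (0.0, 0.2, 0.5):
    print(p, active_types(p))
# 0.0 ['SELF'] / 0.2 ['SELF', 'TOKENS'] / 0.5 ['TOKENS']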