mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'development' into sync-dev-with-main
This commit is contained in:
commit
f04d1bab21
13
README.md
13
README.md
@ -65,14 +65,11 @@ requests. Be sure to use the provided templates. They will help aid diagnose iss
|
|||||||
|
|
||||||
### Installation
|
### Installation
|
||||||
|
|
||||||
This fork is supported across multiple platforms. You can find individual installation instructions
|
This fork is supported across Linux, Windows and Macintosh. Linux
|
||||||
below.
|
users can use either an Nvidia-based card (with CUDA support) or an
|
||||||
|
AMD card (using the ROCm driver). For full installation and upgrade
|
||||||
- #### [Linux](https://invoke-ai.github.io/InvokeAI/installation/INSTALL_LINUX/)
|
instructions, please see:
|
||||||
|
[InvokeAI Installation Overview](https://invoke-ai.github.io/InvokeAI/installation/)
|
||||||
- #### [Windows](https://invoke-ai.github.io/InvokeAI/installation/INSTALL_WINDOWS/)
|
|
||||||
|
|
||||||
- #### [Macintosh](https://invoke-ai.github.io/InvokeAI/installation/INSTALL_MAC/)
|
|
||||||
|
|
||||||
### Hardware Requirements
|
### Hardware Requirements
|
||||||
|
|
||||||
|
@ -30,7 +30,7 @@ model:
|
|||||||
target: ldm.modules.embedding_manager.EmbeddingManager
|
target: ldm.modules.embedding_manager.EmbeddingManager
|
||||||
params:
|
params:
|
||||||
placeholder_strings: ["*"]
|
placeholder_strings: ["*"]
|
||||||
initializer_words: ['face', 'man', 'photo', 'africanmale']
|
initializer_words: ['sculpture']
|
||||||
per_image_tokens: false
|
per_image_tokens: false
|
||||||
num_vectors_per_token: 1
|
num_vectors_per_token: 1
|
||||||
progressive_words: False
|
progressive_words: False
|
||||||
|
@ -30,7 +30,7 @@ model:
|
|||||||
target: ldm.modules.embedding_manager.EmbeddingManager
|
target: ldm.modules.embedding_manager.EmbeddingManager
|
||||||
params:
|
params:
|
||||||
placeholder_strings: ["*"]
|
placeholder_strings: ["*"]
|
||||||
initializer_words: ['face', 'man', 'photo', 'africanmale']
|
initializer_words: ['sculpture']
|
||||||
per_image_tokens: false
|
per_image_tokens: false
|
||||||
num_vectors_per_token: 1
|
num_vectors_per_token: 1
|
||||||
progressive_words: False
|
progressive_words: False
|
||||||
|
@ -22,7 +22,7 @@ model:
|
|||||||
target: ldm.modules.embedding_manager.EmbeddingManager
|
target: ldm.modules.embedding_manager.EmbeddingManager
|
||||||
params:
|
params:
|
||||||
placeholder_strings: ["*"]
|
placeholder_strings: ["*"]
|
||||||
initializer_words: ['face', 'man', 'photo', 'africanmale']
|
initializer_words: ['sculpture']
|
||||||
per_image_tokens: false
|
per_image_tokens: false
|
||||||
num_vectors_per_token: 6
|
num_vectors_per_token: 6
|
||||||
progressive_words: False
|
progressive_words: False
|
||||||
|
@ -80,12 +80,11 @@ Mac and Linux machines, and runs on GPU cards with as little as 4 GB or RAM.
|
|||||||
|
|
||||||
## :octicons-package-dependencies-24: Installation
|
## :octicons-package-dependencies-24: Installation
|
||||||
|
|
||||||
This fork is supported across multiple platforms. You can find individual
|
This fork is supported across Linux, Windows and Macintosh. Linux
|
||||||
installation instructions below.
|
users can use either an Nvidia-based card (with CUDA support) or an
|
||||||
|
AMD card (using the ROCm driver). For full installation and upgrade
|
||||||
- :fontawesome-brands-linux: [Linux](installation/INSTALL_LINUX.md)
|
instructions, please see:
|
||||||
- :fontawesome-brands-windows: [Windows](installation/INSTALL_WINDOWS.md)
|
[InvokeAI Installation Overview](https://invoke-ai.github.io/InvokeAI/installation/)
|
||||||
- :fontawesome-brands-apple: [Macintosh](installation/INSTALL_MAC.md)
|
|
||||||
|
|
||||||
## :fontawesome-solid-computer: Hardware Requirements
|
## :fontawesome-solid-computer: Hardware Requirements
|
||||||
|
|
||||||
@ -123,6 +122,14 @@ You wil need one of the following:
|
|||||||
|
|
||||||
## :octicons-log-16: Latest Changes
|
## :octicons-log-16: Latest Changes
|
||||||
|
|
||||||
|
### v2.1.3 <small>(13 November 2022)</small>
|
||||||
|
|
||||||
|
- A choice of installer scripts that automate installation and configuration. See [Installation](https://github.com/invoke-ai/InvokeAI/blob/2.1.3-rc6/docs/installation/INSTALL.md).
|
||||||
|
- A streamlined manual installation process that works for both Conda and PIP-only installs. See [Manual Installation](https://github.com/invoke-ai/InvokeAI/blob/2.1.3-rc6/docs/installation/INSTALL_MANUAL.md).
|
||||||
|
- The ability to save frequently-used startup options (model to load, steps, sampler, etc) in a `.invokeai` file. See [Client](https://github.com/invoke-ai/InvokeAI/blob/2.1.3-rc6/docs/features/CLI.md)
|
||||||
|
- Support for AMD GPU cards (non-CUDA) on Linux machines.
|
||||||
|
- Multiple bugs and edge cases squashed.
|
||||||
|
|
||||||
### v2.1.0 <small>(2 November 2022)</small>
|
### v2.1.0 <small>(2 November 2022)</small>
|
||||||
|
|
||||||
- [Inpainting](https://invoke-ai.github.io/InvokeAI/features/INPAINTING/)
|
- [Inpainting](https://invoke-ai.github.io/InvokeAI/features/INPAINTING/)
|
||||||
|
@ -19,7 +19,7 @@ set INSTALL_ENV_DIR=%cd%\installer_files\env
|
|||||||
@rem https://mamba.readthedocs.io/en/latest/installation.html
|
@rem https://mamba.readthedocs.io/en/latest/installation.html
|
||||||
set MICROMAMBA_DOWNLOAD_URL=https://github.com/cmdr2/stable-diffusion-ui/releases/download/v1.1/micromamba.exe
|
set MICROMAMBA_DOWNLOAD_URL=https://github.com/cmdr2/stable-diffusion-ui/releases/download/v1.1/micromamba.exe
|
||||||
set RELEASE_URL=https://github.com/invoke-ai/InvokeAI
|
set RELEASE_URL=https://github.com/invoke-ai/InvokeAI
|
||||||
set RELEASE_SOURCEBALL=/archive/refs/heads/v2.1.3.tar.gz
|
set RELEASE_SOURCEBALL=/archive/refs/heads/main.tar.gz
|
||||||
set PYTHON_BUILD_STANDALONE_URL=https://github.com/indygreg/python-build-standalone/releases/download
|
set PYTHON_BUILD_STANDALONE_URL=https://github.com/indygreg/python-build-standalone/releases/download
|
||||||
set PYTHON_BUILD_STANDALONE=20221002/cpython-3.10.7+20221002-x86_64-pc-windows-msvc-shared-install_only.tar.gz
|
set PYTHON_BUILD_STANDALONE=20221002/cpython-3.10.7+20221002-x86_64-pc-windows-msvc-shared-install_only.tar.gz
|
||||||
|
|
||||||
|
@ -74,7 +74,7 @@ fi
|
|||||||
INSTALL_ENV_DIR="$(pwd)/installer_files/env"
|
INSTALL_ENV_DIR="$(pwd)/installer_files/env"
|
||||||
MICROMAMBA_DOWNLOAD_URL="https://micro.mamba.pm/api/micromamba/${MAMBA_OS_NAME}-${MAMBA_ARCH}/latest"
|
MICROMAMBA_DOWNLOAD_URL="https://micro.mamba.pm/api/micromamba/${MAMBA_OS_NAME}-${MAMBA_ARCH}/latest"
|
||||||
RELEASE_URL=https://github.com/invoke-ai/InvokeAI
|
RELEASE_URL=https://github.com/invoke-ai/InvokeAI
|
||||||
RELEASE_SOURCEBALL=/archive/refs/heads/v2.1.3.tar.gz
|
RELEASE_SOURCEBALL=/archive/refs/heads/main.tar.gz
|
||||||
PYTHON_BUILD_STANDALONE_URL=https://github.com/indygreg/python-build-standalone/releases/download
|
PYTHON_BUILD_STANDALONE_URL=https://github.com/indygreg/python-build-standalone/releases/download
|
||||||
if [ "$OS_NAME" == "darwin" ]; then
|
if [ "$OS_NAME" == "darwin" ]; then
|
||||||
PYTHON_BUILD_STANDALONE=20221002/cpython-3.10.7+20221002-${PY_ARCH}-apple-darwin-install_only.tar.gz
|
PYTHON_BUILD_STANDALONE=20221002/cpython-3.10.7+20221002-${PY_ARCH}-apple-darwin-install_only.tar.gz
|
||||||
|
@ -14,7 +14,7 @@ import torch
|
|||||||
|
|
||||||
from .prompt_parser import PromptParser, Blend, FlattenedPrompt, \
|
from .prompt_parser import PromptParser, Blend, FlattenedPrompt, \
|
||||||
CrossAttentionControlledFragment, CrossAttentionControlSubstitute, Fragment, log_tokenization
|
CrossAttentionControlledFragment, CrossAttentionControlSubstitute, Fragment, log_tokenization
|
||||||
from ..models.diffusion.cross_attention_control import CrossAttentionControl
|
from ..models.diffusion import cross_attention_control
|
||||||
from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||||
from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder
|
from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder
|
||||||
|
|
||||||
@ -49,7 +49,7 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
|
|||||||
parsed_negative_prompt: FlattenedPrompt = pp.parse_conjunction(unconditioned_words).prompts[0]
|
parsed_negative_prompt: FlattenedPrompt = pp.parse_conjunction(unconditioned_words).prompts[0]
|
||||||
|
|
||||||
conditioning = None
|
conditioning = None
|
||||||
cac_args:CrossAttentionControl.Arguments = None
|
cac_args:cross_attention_control.Arguments = None
|
||||||
|
|
||||||
if type(parsed_prompt) is Blend:
|
if type(parsed_prompt) is Blend:
|
||||||
blend: Blend = parsed_prompt
|
blend: Blend = parsed_prompt
|
||||||
@ -120,7 +120,7 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
|
|||||||
conditioning = original_embeddings
|
conditioning = original_embeddings
|
||||||
edited_conditioning = edited_embeddings
|
edited_conditioning = edited_embeddings
|
||||||
#print('>> got edit_opcodes', edit_opcodes, 'options', edit_options)
|
#print('>> got edit_opcodes', edit_opcodes, 'options', edit_options)
|
||||||
cac_args = CrossAttentionControl.Arguments(
|
cac_args = cross_attention_control.Arguments(
|
||||||
edited_conditioning = edited_conditioning,
|
edited_conditioning = edited_conditioning,
|
||||||
edit_opcodes = edit_opcodes,
|
edit_opcodes = edit_opcodes,
|
||||||
edit_options = edit_options
|
edit_options = edit_options
|
||||||
|
@ -7,9 +7,6 @@ import torch
|
|||||||
# https://github.com/bloc97/CrossAttentionControl
|
# https://github.com/bloc97/CrossAttentionControl
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class CrossAttentionControl:
|
|
||||||
|
|
||||||
class Arguments:
|
class Arguments:
|
||||||
def __init__(self, edited_conditioning: torch.Tensor, edit_opcodes: list[tuple], edit_options: dict):
|
def __init__(self, edited_conditioning: torch.Tensor, edit_opcodes: list[tuple], edit_options: dict):
|
||||||
"""
|
"""
|
||||||
@ -31,18 +28,30 @@ class CrossAttentionControl:
|
|||||||
self.edit_options = non_none_edit_options[0]
|
self.edit_options = non_none_edit_options[0]
|
||||||
|
|
||||||
|
|
||||||
|
class CrossAttentionType(enum.Enum):
|
||||||
|
SELF = 1
|
||||||
|
TOKENS = 2
|
||||||
|
|
||||||
|
|
||||||
class Context:
|
class Context:
|
||||||
|
|
||||||
|
cross_attention_mask: Optional[torch.Tensor]
|
||||||
|
cross_attention_index_map: Optional[torch.Tensor]
|
||||||
|
|
||||||
class Action(enum.Enum):
|
class Action(enum.Enum):
|
||||||
NONE = 0
|
NONE = 0
|
||||||
SAVE = 1,
|
SAVE = 1,
|
||||||
APPLY = 2
|
APPLY = 2
|
||||||
|
|
||||||
def __init__(self, arguments: 'CrossAttentionControl.Arguments', step_count: int):
|
def __init__(self, arguments: Arguments, step_count: int):
|
||||||
"""
|
"""
|
||||||
:param arguments: Arguments for the cross-attention control process
|
:param arguments: Arguments for the cross-attention control process
|
||||||
:param step_count: The absolute total number of steps of diffusion (for img2img this is likely larger than the number of steps that will actually run)
|
:param step_count: The absolute total number of steps of diffusion (for img2img this is likely larger than the number of steps that will actually run)
|
||||||
"""
|
"""
|
||||||
|
self.cross_attention_mask = None
|
||||||
|
self.cross_attention_index_map = None
|
||||||
|
self.self_cross_attention_action = Context.Action.NONE
|
||||||
|
self.tokens_cross_attention_action = Context.Action.NONE
|
||||||
self.arguments = arguments
|
self.arguments = arguments
|
||||||
self.step_count = step_count
|
self.step_count = step_count
|
||||||
|
|
||||||
@ -54,58 +63,56 @@ class CrossAttentionControl:
|
|||||||
self.clear_requests(cleanup=True)
|
self.clear_requests(cleanup=True)
|
||||||
|
|
||||||
def register_cross_attention_modules(self, model):
|
def register_cross_attention_modules(self, model):
|
||||||
for name,module in CrossAttentionControl.get_attention_modules(model,
|
for name,module in get_attention_modules(model, CrossAttentionType.SELF):
|
||||||
CrossAttentionControl.CrossAttentionType.SELF):
|
|
||||||
self.self_cross_attention_module_identifiers.append(name)
|
self.self_cross_attention_module_identifiers.append(name)
|
||||||
for name,module in CrossAttentionControl.get_attention_modules(model,
|
for name,module in get_attention_modules(model, CrossAttentionType.TOKENS):
|
||||||
CrossAttentionControl.CrossAttentionType.TOKENS):
|
|
||||||
self.tokens_cross_attention_module_identifiers.append(name)
|
self.tokens_cross_attention_module_identifiers.append(name)
|
||||||
|
|
||||||
def request_save_attention_maps(self, cross_attention_type: 'CrossAttentionControl.CrossAttentionType'):
|
def request_save_attention_maps(self, cross_attention_type: CrossAttentionType):
|
||||||
if cross_attention_type == CrossAttentionControl.CrossAttentionType.SELF:
|
if cross_attention_type == CrossAttentionType.SELF:
|
||||||
self.self_cross_attention_action = CrossAttentionControl.Context.Action.SAVE
|
self.self_cross_attention_action = Context.Action.SAVE
|
||||||
else:
|
else:
|
||||||
self.tokens_cross_attention_action = CrossAttentionControl.Context.Action.SAVE
|
self.tokens_cross_attention_action = Context.Action.SAVE
|
||||||
|
|
||||||
def request_apply_saved_attention_maps(self, cross_attention_type: 'CrossAttentionControl.CrossAttentionType'):
|
def request_apply_saved_attention_maps(self, cross_attention_type: CrossAttentionType):
|
||||||
if cross_attention_type == CrossAttentionControl.CrossAttentionType.SELF:
|
if cross_attention_type == CrossAttentionType.SELF:
|
||||||
self.self_cross_attention_action = CrossAttentionControl.Context.Action.APPLY
|
self.self_cross_attention_action = Context.Action.APPLY
|
||||||
else:
|
else:
|
||||||
self.tokens_cross_attention_action = CrossAttentionControl.Context.Action.APPLY
|
self.tokens_cross_attention_action = Context.Action.APPLY
|
||||||
|
|
||||||
def is_tokens_cross_attention(self, module_identifier) -> bool:
|
def is_tokens_cross_attention(self, module_identifier) -> bool:
|
||||||
return module_identifier in self.tokens_cross_attention_module_identifiers
|
return module_identifier in self.tokens_cross_attention_module_identifiers
|
||||||
|
|
||||||
def get_should_save_maps(self, module_identifier: str) -> bool:
|
def get_should_save_maps(self, module_identifier: str) -> bool:
|
||||||
if module_identifier in self.self_cross_attention_module_identifiers:
|
if module_identifier in self.self_cross_attention_module_identifiers:
|
||||||
return self.self_cross_attention_action == CrossAttentionControl.Context.Action.SAVE
|
return self.self_cross_attention_action == Context.Action.SAVE
|
||||||
elif module_identifier in self.tokens_cross_attention_module_identifiers:
|
elif module_identifier in self.tokens_cross_attention_module_identifiers:
|
||||||
return self.tokens_cross_attention_action == CrossAttentionControl.Context.Action.SAVE
|
return self.tokens_cross_attention_action == Context.Action.SAVE
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def get_should_apply_saved_maps(self, module_identifier: str) -> bool:
|
def get_should_apply_saved_maps(self, module_identifier: str) -> bool:
|
||||||
if module_identifier in self.self_cross_attention_module_identifiers:
|
if module_identifier in self.self_cross_attention_module_identifiers:
|
||||||
return self.self_cross_attention_action == CrossAttentionControl.Context.Action.APPLY
|
return self.self_cross_attention_action == Context.Action.APPLY
|
||||||
elif module_identifier in self.tokens_cross_attention_module_identifiers:
|
elif module_identifier in self.tokens_cross_attention_module_identifiers:
|
||||||
return self.tokens_cross_attention_action == CrossAttentionControl.Context.Action.APPLY
|
return self.tokens_cross_attention_action == Context.Action.APPLY
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def get_active_cross_attention_control_types_for_step(self, percent_through:float=None)\
|
def get_active_cross_attention_control_types_for_step(self, percent_through:float=None)\
|
||||||
-> list['CrossAttentionControl.CrossAttentionType']:
|
-> list[CrossAttentionType]:
|
||||||
"""
|
"""
|
||||||
Should cross-attention control be applied on the given step?
|
Should cross-attention control be applied on the given step?
|
||||||
:param percent_through: How far through the step sequence are we (0.0=pure noise, 1.0=completely denoised image). Expected range 0.0..<1.0.
|
:param percent_through: How far through the step sequence are we (0.0=pure noise, 1.0=completely denoised image). Expected range 0.0..<1.0.
|
||||||
:return: A list of attention types that cross-attention control should be performed for on the given step. May be [].
|
:return: A list of attention types that cross-attention control should be performed for on the given step. May be [].
|
||||||
"""
|
"""
|
||||||
if percent_through is None:
|
if percent_through is None:
|
||||||
return [CrossAttentionControl.CrossAttentionType.SELF, CrossAttentionControl.CrossAttentionType.TOKENS]
|
return [CrossAttentionType.SELF, CrossAttentionType.TOKENS]
|
||||||
|
|
||||||
opts = self.arguments.edit_options
|
opts = self.arguments.edit_options
|
||||||
to_control = []
|
to_control = []
|
||||||
if opts['s_start'] <= percent_through and percent_through < opts['s_end']:
|
if opts['s_start'] <= percent_through < opts['s_end']:
|
||||||
to_control.append(CrossAttentionControl.CrossAttentionType.SELF)
|
to_control.append(CrossAttentionType.SELF)
|
||||||
if opts['t_start'] <= percent_through and percent_through < opts['t_end']:
|
if opts['t_start'] <= percent_through < opts['t_end']:
|
||||||
to_control.append(CrossAttentionControl.CrossAttentionType.TOKENS)
|
to_control.append(CrossAttentionType.TOKENS)
|
||||||
return to_control
|
return to_control
|
||||||
|
|
||||||
def save_slice(self, identifier: str, slice: torch.Tensor, dim: Optional[int], offset: int,
|
def save_slice(self, identifier: str, slice: torch.Tensor, dim: Optional[int], offset: int,
|
||||||
@ -132,7 +139,7 @@ class CrossAttentionControl:
|
|||||||
f"slice_size mismatch: expected slice_size={slice_size}, have {saved_attention_dict['slice_size']}")
|
f"slice_size mismatch: expected slice_size={slice_size}, have {saved_attention_dict['slice_size']}")
|
||||||
return saved_attention_dict['slices'][requested_offset]
|
return saved_attention_dict['slices'][requested_offset]
|
||||||
|
|
||||||
if saved_attention_dict['dim'] == None:
|
if saved_attention_dict['dim'] is None:
|
||||||
whole_saved_attention = saved_attention_dict['slices'][0]
|
whole_saved_attention = saved_attention_dict['slices'][0]
|
||||||
if requested_dim == 0:
|
if requested_dim == 0:
|
||||||
return whole_saved_attention[requested_offset:requested_offset + slice_size]
|
return whole_saved_attention[requested_offset:requested_offset + slice_size]
|
||||||
@ -141,15 +148,15 @@ class CrossAttentionControl:
|
|||||||
|
|
||||||
raise RuntimeError(f"Cannot convert dim {saved_attention_dict['dim']} to requested dim {requested_dim}")
|
raise RuntimeError(f"Cannot convert dim {saved_attention_dict['dim']} to requested dim {requested_dim}")
|
||||||
|
|
||||||
def get_slicing_strategy(self, identifier: str) -> Optional[tuple[int, int]]:
|
def get_slicing_strategy(self, identifier: str) -> tuple[Optional[int], Optional[int]]:
|
||||||
saved_attention = self.saved_cross_attention_maps.get(identifier, None)
|
saved_attention = self.saved_cross_attention_maps.get(identifier, None)
|
||||||
if saved_attention is None:
|
if saved_attention is None:
|
||||||
return None, None
|
return None, None
|
||||||
return saved_attention['dim'], saved_attention['slice_size']
|
return saved_attention['dim'], saved_attention['slice_size']
|
||||||
|
|
||||||
def clear_requests(self, cleanup=True):
|
def clear_requests(self, cleanup=True):
|
||||||
self.tokens_cross_attention_action = CrossAttentionControl.Context.Action.NONE
|
self.tokens_cross_attention_action = Context.Action.NONE
|
||||||
self.self_cross_attention_action = CrossAttentionControl.Context.Action.NONE
|
self.self_cross_attention_action = Context.Action.NONE
|
||||||
if cleanup:
|
if cleanup:
|
||||||
self.saved_cross_attention_maps = {}
|
self.saved_cross_attention_maps = {}
|
||||||
|
|
||||||
@ -158,12 +165,12 @@ class CrossAttentionControl:
|
|||||||
for offset, slice in map_dict['slices'].items():
|
for offset, slice in map_dict['slices'].items():
|
||||||
map_dict[offset] = slice.to('cpu')
|
map_dict[offset] = slice.to('cpu')
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def remove_cross_attention_control(cls, model):
|
|
||||||
cls.remove_attention_function(model)
|
|
||||||
|
|
||||||
@classmethod
|
def remove_cross_attention_control(model):
|
||||||
def setup_cross_attention_control(cls, model, context: Context):
|
remove_attention_function(model)
|
||||||
|
|
||||||
|
|
||||||
|
def setup_cross_attention_control(model, context: Context):
|
||||||
"""
|
"""
|
||||||
Inject attention parameters and functions into the passed in model to enable cross attention editing.
|
Inject attention parameters and functions into the passed in model to enable cross attention editing.
|
||||||
|
|
||||||
@ -191,22 +198,16 @@ class CrossAttentionControl:
|
|||||||
context.register_cross_attention_modules(model)
|
context.register_cross_attention_modules(model)
|
||||||
context.cross_attention_mask = mask.to(device)
|
context.cross_attention_mask = mask.to(device)
|
||||||
context.cross_attention_index_map = indices.to(device)
|
context.cross_attention_index_map = indices.to(device)
|
||||||
cls.inject_attention_function(model, context)
|
inject_attention_function(model, context)
|
||||||
|
|
||||||
|
|
||||||
class CrossAttentionType(enum.Enum):
|
def get_attention_modules(model, which: CrossAttentionType):
|
||||||
SELF = 1
|
which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2"
|
||||||
TOKENS = 2
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_attention_modules(cls, model, which: CrossAttentionType):
|
|
||||||
which_attn = "attn1" if which is cls.CrossAttentionType.SELF else "attn2"
|
|
||||||
return [(name,module) for name, module in model.named_modules() if
|
return [(name,module) for name, module in model.named_modules() if
|
||||||
type(module).__name__ == "CrossAttention" and which_attn in name]
|
type(module).__name__ == "CrossAttention" and which_attn in name]
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
def inject_attention_function(unet, context: Context):
|
||||||
def inject_attention_function(cls, unet, context: 'CrossAttentionControl.Context'):
|
|
||||||
# ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
|
# ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
|
||||||
|
|
||||||
def attention_slice_wrangler(module, suggested_attention_slice:torch.Tensor, dim, offset, slice_size):
|
def attention_slice_wrangler(module, suggested_attention_slice:torch.Tensor, dim, offset, slice_size):
|
||||||
@ -251,12 +252,11 @@ class CrossAttentionControl:
|
|||||||
module.set_slicing_strategy_getter(lambda module, module_identifier=name: \
|
module.set_slicing_strategy_getter(lambda module, module_identifier=name: \
|
||||||
context.get_slicing_strategy(module_identifier))
|
context.get_slicing_strategy(module_identifier))
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def remove_attention_function(cls, unet):
|
def remove_attention_function(unet):
|
||||||
# clear wrangler callback
|
# clear wrangler callback
|
||||||
for name, module in unet.named_modules():
|
for name, module in unet.named_modules():
|
||||||
module_name = type(module).__name__
|
module_name = type(module).__name__
|
||||||
if module_name == "CrossAttention":
|
if module_name == "CrossAttention":
|
||||||
module.set_attention_slice_wrangler(None)
|
module.set_attention_slice_wrangler(None)
|
||||||
module.set_slicing_strategy_getter(None)
|
module.set_slicing_strategy_getter(None)
|
||||||
|
|
||||||
|
@ -4,7 +4,8 @@ from typing import Callable, Optional, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ldm.models.diffusion.cross_attention_control import CrossAttentionControl
|
from ldm.models.diffusion.cross_attention_control import Arguments, \
|
||||||
|
remove_cross_attention_control, setup_cross_attention_control, Context
|
||||||
from ldm.modules.attention import get_mem_free_total
|
from ldm.modules.attention import get_mem_free_total
|
||||||
|
|
||||||
|
|
||||||
@ -20,7 +21,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
|
|
||||||
|
|
||||||
class ExtraConditioningInfo:
|
class ExtraConditioningInfo:
|
||||||
def __init__(self, cross_attention_control_args: Optional[CrossAttentionControl.Arguments]):
|
def __init__(self, cross_attention_control_args: Optional[Arguments]):
|
||||||
self.cross_attention_control_args = cross_attention_control_args
|
self.cross_attention_control_args = cross_attention_control_args
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -40,16 +41,16 @@ class InvokeAIDiffuserComponent:
|
|||||||
|
|
||||||
def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo, step_count: int):
|
def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo, step_count: int):
|
||||||
self.conditioning = conditioning
|
self.conditioning = conditioning
|
||||||
self.cross_attention_control_context = CrossAttentionControl.Context(
|
self.cross_attention_control_context = Context(
|
||||||
arguments=self.conditioning.cross_attention_control_args,
|
arguments=self.conditioning.cross_attention_control_args,
|
||||||
step_count=step_count
|
step_count=step_count
|
||||||
)
|
)
|
||||||
CrossAttentionControl.setup_cross_attention_control(self.model, self.cross_attention_control_context)
|
setup_cross_attention_control(self.model, self.cross_attention_control_context)
|
||||||
|
|
||||||
def remove_cross_attention_control(self):
|
def remove_cross_attention_control(self):
|
||||||
self.conditioning = None
|
self.conditioning = None
|
||||||
self.cross_attention_control_context = None
|
self.cross_attention_control_context = None
|
||||||
CrossAttentionControl.remove_cross_attention_control(self.model)
|
remove_cross_attention_control(self.model)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -71,7 +72,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
|
|
||||||
|
|
||||||
cross_attention_control_types_to_do = []
|
cross_attention_control_types_to_do = []
|
||||||
context: CrossAttentionControl.Context = self.cross_attention_control_context
|
context: Context = self.cross_attention_control_context
|
||||||
if self.cross_attention_control_context is not None:
|
if self.cross_attention_control_context is not None:
|
||||||
percent_through = self.estimate_percent_through(step_index, sigma)
|
percent_through = self.estimate_percent_through(step_index, sigma)
|
||||||
cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(percent_through)
|
cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(percent_through)
|
||||||
@ -133,7 +134,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
# representing batched uncond + cond, but then when it comes to applying the saved attention, the
|
# representing batched uncond + cond, but then when it comes to applying the saved attention, the
|
||||||
# wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.)
|
# wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.)
|
||||||
# todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well.
|
# todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well.
|
||||||
context:CrossAttentionControl.Context = self.cross_attention_control_context
|
context:Context = self.cross_attention_control_context
|
||||||
|
|
||||||
try:
|
try:
|
||||||
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
|
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user