move all prompting stuff to use compel

This commit is contained in:
Damian Stewart
2023-02-19 20:42:29 +01:00
parent b9ecf93ba3
commit ded3f13a33
11 changed files with 104 additions and 1646 deletions

View File

@ -1,3 +1,8 @@
# adapted from bloc97's CrossAttentionControl colab
# https://github.com/bloc97/CrossAttentionControl
import enum
import math
from typing import Optional, Callable
@ -6,35 +11,13 @@ import psutil
import torch
import diffusers
from torch import nn
from compel.cross_attention_control import Arguments
from diffusers.models.unet_2d_condition import UNet2DConditionModel
from diffusers.models.cross_attention import AttnProcessor
from ldm.invoke.devices import torch_dtype
# adapted from bloc97's CrossAttentionControl colab
# https://github.com/bloc97/CrossAttentionControl
class Arguments:
def __init__(self, edited_conditioning: torch.Tensor, edit_opcodes: list[tuple], edit_options: dict):
"""
:param edited_conditioning: if doing cross-attention control, the edited conditioning [1 x 77 x 768]
:param edit_opcodes: if doing cross-attention control, a list of difflib.SequenceMatcher-like opcodes describing how to map original conditioning tokens to edited conditioning tokens (only the 'equal' opcode is required)
:param edit_options: if doing cross-attention control, per-edit options. there should be 1 item in edit_options for each item in edit_opcodes.
"""
# todo: rewrite this to take embedding fragments rather than a single edited_conditioning vector
self.edited_conditioning = edited_conditioning
self.edit_opcodes = edit_opcodes
if edited_conditioning is not None:
assert len(edit_opcodes) == len(edit_options), \
"there must be 1 edit_options dict for each edit_opcodes tuple"
non_none_edit_options = [x for x in edit_options if x is not None]
assert len(non_none_edit_options)>0, "missing edit_options"
if len(non_none_edit_options)>1:
print('warning: cross-attention control options are not working properly for >1 edit')
self.edit_options = non_none_edit_options[0]
class CrossAttentionType(enum.Enum):
SELF = 1
TOKENS = 2
@ -319,7 +302,6 @@ def override_cross_attention(model, context: Context, is_running_diffusers = Fal
Inject attention parameters and functions into the passed in model to enable cross attention editing.
:param model: The unet model to inject into.
:param cross_attention_control_args: Arugments passeed to the CrossAttentionControl implementations
:return: None
"""
@ -523,7 +505,7 @@ from dataclasses import field, dataclass
import torch
from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor, SlicedAttnProcessor, AttnProcessor
from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor, SlicedAttnProcessor
@dataclass