rename StrcuturedConditioning to ExtraConditioningInfo

This commit is contained in:
Damian at mba 2022-10-21 12:18:40 +02:00
parent 8142b72bcd
commit 2bf9f1f0d8
4 changed files with 18 additions and 21 deletions

View File

@ -34,7 +34,7 @@ class Img2Img(Generator):
t_enc = int(strength * steps)
uc, c, ec, edit_opcodes = conditioning
extra_conditioning_info = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes)
extra_conditioning_info = InvokeAIDiffuserComponent.ExtraConditioningInfo(edited_conditioning=ec, edit_opcodes=edit_opcodes)
def make_image(x_T):
# encode (scaled latent)

View File

@ -22,7 +22,7 @@ class Txt2Img(Generator):
"""
self.perlin = perlin
uc, c, ec, edit_opcodes = conditioning
extra_conditioning_info = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes)
extra_conditioning_info = InvokeAIDiffuserComponent.ExtraConditioningInfo(edited_conditioning=ec, edit_opcodes=edit_opcodes)
@torch.no_grad()
def make_image(x_T):

View File

@ -24,7 +24,7 @@ class Txt2Img2Img(Generator):
kwargs are 'width' and 'height'
"""
uc, c, ec, edit_opcodes = conditioning
extra_conditioning_info = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes)
extra_conditioning_info = InvokeAIDiffuserComponent.ExtraConditioningInfo(edited_conditioning=ec, edit_opcodes=edit_opcodes)
@torch.no_grad()
def make_image(x_T):

View File

@ -6,23 +6,6 @@ import torch
class InvokeAIDiffuserComponent:
class StructuredConditioning:
def __init__(self, edited_conditioning: torch.Tensor = None, edit_opcodes: list[tuple] = None):
"""
:param edited_conditioning: if doing cross-attention control, the edited conditioning (1 x 77 x 768)
:param edit_opcodes: if doing cross-attention control, opcodes from a SequenceMatcher describing how to map original conditioning tokens to edited conditioning tokens
"""
# TODO migrate conditioning and unconditioning here, too
#self.conditioning = conditioning
#self.unconditioning = unconditioning
self.edited_conditioning = edited_conditioning
self.edit_opcodes = edit_opcodes
@property
def wants_cross_attention_control(self):
return self.edited_conditioning is not None
'''
The aim of this component is to provide a single place for code that can be applied identically to
all InvokeAI diffusion procedures.
@ -31,6 +14,20 @@ class InvokeAIDiffuserComponent:
* Cross Attention Control ("prompt2prompt")
'''
class ExtraConditioningInfo:
def __init__(self, edited_conditioning: torch.Tensor = None, edit_opcodes: list[tuple] = None):
"""
:param edited_conditioning: if doing cross-attention control, the edited conditioning (1 x 77 x 768)
:param edit_opcodes: if doing cross-attention control, opcodes from a SequenceMatcher describing how to map original conditioning tokens to edited conditioning tokens
"""
self.edited_conditioning = edited_conditioning
self.edit_opcodes = edit_opcodes
@property
def wants_cross_attention_control(self):
return self.edited_conditioning is not None
def __init__(self, model, model_forward_callback: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]):
"""
:param model: the unet model to pass through to cross attention control
@ -40,7 +37,7 @@ class InvokeAIDiffuserComponent:
self.model_forward_callback = model_forward_callback
def setup_cross_attention_control(self, conditioning: StructuredConditioning):
def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo):
self.conditioning = conditioning
CrossAttentionControl.setup_cross_attention_control(self.model, conditioning.edited_conditioning, conditioning.edit_opcodes)