mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
rename StrcuturedConditioning to ExtraConditioningInfo
This commit is contained in:
parent
8142b72bcd
commit
2bf9f1f0d8
@ -34,7 +34,7 @@ class Img2Img(Generator):
|
||||
|
||||
t_enc = int(strength * steps)
|
||||
uc, c, ec, edit_opcodes = conditioning
|
||||
extra_conditioning_info = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes)
|
||||
extra_conditioning_info = InvokeAIDiffuserComponent.ExtraConditioningInfo(edited_conditioning=ec, edit_opcodes=edit_opcodes)
|
||||
|
||||
def make_image(x_T):
|
||||
# encode (scaled latent)
|
||||
|
@ -22,7 +22,7 @@ class Txt2Img(Generator):
|
||||
"""
|
||||
self.perlin = perlin
|
||||
uc, c, ec, edit_opcodes = conditioning
|
||||
extra_conditioning_info = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes)
|
||||
extra_conditioning_info = InvokeAIDiffuserComponent.ExtraConditioningInfo(edited_conditioning=ec, edit_opcodes=edit_opcodes)
|
||||
|
||||
@torch.no_grad()
|
||||
def make_image(x_T):
|
||||
|
@ -24,7 +24,7 @@ class Txt2Img2Img(Generator):
|
||||
kwargs are 'width' and 'height'
|
||||
"""
|
||||
uc, c, ec, edit_opcodes = conditioning
|
||||
extra_conditioning_info = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes)
|
||||
extra_conditioning_info = InvokeAIDiffuserComponent.ExtraConditioningInfo(edited_conditioning=ec, edit_opcodes=edit_opcodes)
|
||||
|
||||
@torch.no_grad()
|
||||
def make_image(x_T):
|
||||
|
@ -6,23 +6,6 @@ import torch
|
||||
|
||||
|
||||
class InvokeAIDiffuserComponent:
|
||||
|
||||
class StructuredConditioning:
|
||||
def __init__(self, edited_conditioning: torch.Tensor = None, edit_opcodes: list[tuple] = None):
|
||||
"""
|
||||
:param edited_conditioning: if doing cross-attention control, the edited conditioning (1 x 77 x 768)
|
||||
:param edit_opcodes: if doing cross-attention control, opcodes from a SequenceMatcher describing how to map original conditioning tokens to edited conditioning tokens
|
||||
"""
|
||||
# TODO migrate conditioning and unconditioning here, too
|
||||
#self.conditioning = conditioning
|
||||
#self.unconditioning = unconditioning
|
||||
self.edited_conditioning = edited_conditioning
|
||||
self.edit_opcodes = edit_opcodes
|
||||
|
||||
@property
|
||||
def wants_cross_attention_control(self):
|
||||
return self.edited_conditioning is not None
|
||||
|
||||
'''
|
||||
The aim of this component is to provide a single place for code that can be applied identically to
|
||||
all InvokeAI diffusion procedures.
|
||||
@ -31,6 +14,20 @@ class InvokeAIDiffuserComponent:
|
||||
* Cross Attention Control ("prompt2prompt")
|
||||
'''
|
||||
|
||||
|
||||
class ExtraConditioningInfo:
|
||||
def __init__(self, edited_conditioning: torch.Tensor = None, edit_opcodes: list[tuple] = None):
|
||||
"""
|
||||
:param edited_conditioning: if doing cross-attention control, the edited conditioning (1 x 77 x 768)
|
||||
:param edit_opcodes: if doing cross-attention control, opcodes from a SequenceMatcher describing how to map original conditioning tokens to edited conditioning tokens
|
||||
"""
|
||||
self.edited_conditioning = edited_conditioning
|
||||
self.edit_opcodes = edit_opcodes
|
||||
|
||||
@property
|
||||
def wants_cross_attention_control(self):
|
||||
return self.edited_conditioning is not None
|
||||
|
||||
def __init__(self, model, model_forward_callback: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]):
|
||||
"""
|
||||
:param model: the unet model to pass through to cross attention control
|
||||
@ -40,7 +37,7 @@ class InvokeAIDiffuserComponent:
|
||||
self.model_forward_callback = model_forward_callback
|
||||
|
||||
|
||||
def setup_cross_attention_control(self, conditioning: StructuredConditioning):
|
||||
def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo):
|
||||
self.conditioning = conditioning
|
||||
CrossAttentionControl.setup_cross_attention_control(self.model, conditioning.edited_conditioning, conditioning.edit_opcodes)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user