diff --git a/ldm/invoke/generator/img2img.py b/ldm/invoke/generator/img2img.py index 6fa0d0c6dd..cfe3ff99bc 100644 --- a/ldm/invoke/generator/img2img.py +++ b/ldm/invoke/generator/img2img.py @@ -34,7 +34,7 @@ class Img2Img(Generator): t_enc = int(strength * steps) uc, c, ec, edit_opcodes = conditioning - extra_conditioning_info = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes) + extra_conditioning_info = InvokeAIDiffuserComponent.ExtraConditioningInfo(edited_conditioning=ec, edit_opcodes=edit_opcodes) def make_image(x_T): # encode (scaled latent) diff --git a/ldm/invoke/generator/txt2img.py b/ldm/invoke/generator/txt2img.py index 657cccc592..7e739860c3 100644 --- a/ldm/invoke/generator/txt2img.py +++ b/ldm/invoke/generator/txt2img.py @@ -22,7 +22,7 @@ class Txt2Img(Generator): """ self.perlin = perlin uc, c, ec, edit_opcodes = conditioning - extra_conditioning_info = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes) + extra_conditioning_info = InvokeAIDiffuserComponent.ExtraConditioningInfo(edited_conditioning=ec, edit_opcodes=edit_opcodes) @torch.no_grad() def make_image(x_T): diff --git a/ldm/invoke/generator/txt2img2img.py b/ldm/invoke/generator/txt2img2img.py index 64d0468418..2d67a44346 100644 --- a/ldm/invoke/generator/txt2img2img.py +++ b/ldm/invoke/generator/txt2img2img.py @@ -24,7 +24,7 @@ class Txt2Img2Img(Generator): kwargs are 'width' and 'height' """ uc, c, ec, edit_opcodes = conditioning - extra_conditioning_info = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes) + extra_conditioning_info = InvokeAIDiffuserComponent.ExtraConditioningInfo(edited_conditioning=ec, edit_opcodes=edit_opcodes) @torch.no_grad() def make_image(x_T): diff --git a/ldm/models/diffusion/shared_invokeai_diffusion.py b/ldm/models/diffusion/shared_invokeai_diffusion.py index 0a613091d5..290925fc8c 100644 --- a/ldm/models/diffusion/shared_invokeai_diffusion.py +++ b/ldm/models/diffusion/shared_invokeai_diffusion.py @@ -6,23 +6,6 @@ import torch class InvokeAIDiffuserComponent: - - class StructuredConditioning: - def __init__(self, edited_conditioning: torch.Tensor = None, edit_opcodes: list[tuple] = None): - """ - :param edited_conditioning: if doing cross-attention control, the edited conditioning (1 x 77 x 768) - :param edit_opcodes: if doing cross-attention control, opcodes from a SequenceMatcher describing how to map original conditioning tokens to edited conditioning tokens - """ - # TODO migrate conditioning and unconditioning here, too - #self.conditioning = conditioning - #self.unconditioning = unconditioning - self.edited_conditioning = edited_conditioning - self.edit_opcodes = edit_opcodes - - @property - def wants_cross_attention_control(self): - return self.edited_conditioning is not None - ''' The aim of this component is to provide a single place for code that can be applied identically to all InvokeAI diffusion procedures. @@ -31,6 +14,20 @@ class InvokeAIDiffuserComponent: * Cross Attention Control ("prompt2prompt") ''' + + class ExtraConditioningInfo: + def __init__(self, edited_conditioning: torch.Tensor = None, edit_opcodes: list[tuple] = None): + """ + :param edited_conditioning: if doing cross-attention control, the edited conditioning (1 x 77 x 768) + :param edit_opcodes: if doing cross-attention control, opcodes from a SequenceMatcher describing how to map original conditioning tokens to edited conditioning tokens + """ + self.edited_conditioning = edited_conditioning + self.edit_opcodes = edit_opcodes + + @property + def wants_cross_attention_control(self): + return self.edited_conditioning is not None + def __init__(self, model, model_forward_callback: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]): """ :param model: the unet model to pass through to cross attention control @@ -40,7 +37,7 @@ class InvokeAIDiffuserComponent: self.model_forward_callback = model_forward_callback - def setup_cross_attention_control(self, conditioning: StructuredConditioning): + def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo): self.conditioning = conditioning CrossAttentionControl.setup_cross_attention_control(self.model, conditioning.edited_conditioning, conditioning.edit_opcodes)