Refactored single diffusion path; seems to be working for all samplers.

This commit is contained in:
Damian at mba 2022-10-19 19:57:20 +02:00
parent 147d39cb7c
commit 1ffd4a9e06
7 changed files with 57 additions and 52 deletions

View File

@ -7,6 +7,7 @@ import numpy as np
from ldm.invoke.devices import choose_autocast
from ldm.invoke.generator.base import Generator
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
class Img2Img(Generator):
def __init__(self, model, precision):
@ -33,6 +34,7 @@ class Img2Img(Generator):
t_enc = int(strength * steps)
uc, c, ec, edit_opcodes = conditioning
structured_conditioning = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes)
def make_image(x_T):
# encode (scaled latent)
@ -50,8 +52,7 @@ class Img2Img(Generator):
unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,
init_latent = self.init_latent,
edited_conditioning = ec,
conditioning_edit_opcodes = edit_opcodes
structured_conditioning = structured_conditioning
# changes how noising is performed in ksampler
)

View File

@ -5,6 +5,8 @@ ldm.invoke.generator.txt2img inherits from ldm.invoke.generator
import torch
import numpy as np
from ldm.invoke.generator.base import Generator
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
class Txt2Img(Generator):
def __init__(self, model, precision):
@ -20,6 +22,7 @@ class Txt2Img(Generator):
"""
self.perlin = perlin
uc, c, ec, edit_opcodes = conditioning
structured_conditioning = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes)
@torch.no_grad()
def make_image(x_T):
@ -43,8 +46,7 @@ class Txt2Img(Generator):
verbose = False,
unconditional_guidance_scale = cfg_scale,
unconditional_conditioning = uc,
edited_conditioning = ec,
conditioning_edit_opcodes = edit_opcodes,
structured_conditioning = structured_conditioning,
eta = ddim_eta,
img_callback = step_callback,
threshold = threshold,

View File

@ -7,6 +7,7 @@ import numpy as np
import math
from ldm.invoke.generator.base import Generator
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
class Txt2Img2Img(Generator):
@ -23,6 +24,7 @@ class Txt2Img2Img(Generator):
kwargs are 'width' and 'height'
"""
uc, c, ec, edit_opcodes = conditioning
structured_conditioning = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes)
@torch.no_grad()
def make_image(x_T):
@ -61,8 +63,7 @@ class Txt2Img2Img(Generator):
unconditional_conditioning = uc,
eta = ddim_eta,
img_callback = step_callback,
edited_conditioning = ec,
conditioning_edit_opcodes = edit_opcodes
structured_conditioning = structured_conditioning
)
print(
@ -96,8 +97,7 @@ class Txt2Img2Img(Generator):
img_callback = step_callback,
unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,
edited_conditioning = ec,
conditioning_edit_opcodes = edit_opcodes
structured_conditioning = structured_conditioning
)
if self.free_gpu_mem:

View File

@ -2,10 +2,6 @@
from typing import Union
import torch
import numpy as np
from tqdm import tqdm
from functools import partial
from ldm.invoke.devices import choose_torch_device
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from ldm.models.diffusion.sampler import Sampler
from ldm.modules.diffusionmodules.util import noise_like
@ -20,13 +16,12 @@ class DDIMSampler(Sampler):
def prepare_to_sample(self, t_enc, **kwargs):
super().prepare_to_sample(t_enc, **kwargs)
edited_conditioning = kwargs.get('edited_conditioning', None)
structured_conditioning = kwargs.get('structured_conditioning', None)
if edited_conditioning is not None:
edit_opcodes = kwargs.get('conditioning_edit_opcodes', None)
self.invokeai_diffuser.setup_cross_attention_control(edited_conditioning, edit_opcodes)
if structured_conditioning is not None and structured_conditioning.wants_cross_attention_control:
self.invokeai_diffuser.setup_cross_attention_control(structured_conditioning)
else:
self.invokeai_diffuser.cleanup_cross_attention_control()
self.invokeai_diffuser.remove_cross_attention_control()
# This is the central routine
@ -54,10 +49,9 @@ class DDIMSampler(Sampler):
unconditional_conditioning is None
or unconditional_guidance_scale == 1.0
):
# damian0815 does not think this code path is ever used
# damian0815 would like to know when/if this code path is used
e_t = self.model.apply_model(x, t, c)
else:
e_t = self.invokeai_diffuser.do_diffusion_step(x, t, unconditional_conditioning, c, unconditional_guidance_scale)
if score_corrector is not None:

View File

@ -34,18 +34,17 @@ class CFGDenoiser(nn.Module):
def prepare_to_sample(self, t_enc, **kwargs):
edited_conditioning = kwargs.get('edited_conditioning', None)
structured_conditioning = kwargs.get('structured_conditioning', None)
if edited_conditioning is not None:
conditioning_edit_opcodes = kwargs.get('conditioning_edit_opcodes', None)
self.invokeai_diffuser.setup_cross_attention_control(edited_conditioning, conditioning_edit_opcodes)
if structured_conditioning is not None and structured_conditioning.wants_cross_attention_control:
self.invokeai_diffuser.setup_cross_attention_control(structured_conditioning)
else:
self.invokeai_diffuser.cleanup_cross_attention_control()
self.invokeai_diffuser.remove_cross_attention_control()
def forward(self, x, sigma, uncond, cond, cond_scale):
final_next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale)
next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale)
# apply threshold
if self.warmup < self.warmup_max:
@ -55,7 +54,7 @@ class CFGDenoiser(nn.Module):
thresh = self.threshold
if thresh > self.threshold:
thresh = self.threshold
return cfg_apply_threshold(final_next_x, thresh)
return cfg_apply_threshold(next_x, thresh)
@ -165,8 +164,7 @@ class KSampler(Sampler):
log_every_t=100,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
edited_conditioning=None,
conditioning_edit_opcodes=None,
structured_conditioning=None,
threshold = 0,
perlin = 0,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
@ -199,7 +197,7 @@ class KSampler(Sampler):
x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0]
model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10))
model_wrap_cfg.prepare_to_sample(S, edited_conditioning=edited_conditioning, conditioning_edit_opcodes=conditioning_edit_opcodes)
model_wrap_cfg.prepare_to_sample(S, structured_conditioning=structured_conditioning)
extra_args = {
'cond': conditioning,
'uncond': unconditional_conditioning,
@ -226,8 +224,7 @@ class KSampler(Sampler):
index,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
edited_conditioning=None,
conditioning_edit_opcodes=None,
structured_conditioning=None,
**kwargs,
):
if self.model_wrap is None:
@ -253,7 +250,7 @@ class KSampler(Sampler):
# so the actual formula for indexing into sigmas:
# sigma_index = (steps-index)
s_index = t_enc - index - 1
self.model_wrap.prepare_to_sample(s_index, edited_conditioning=edited_conditioning, conditioning_edit_opcodes=conditioning_edit_opcodes)
self.model_wrap.prepare_to_sample(s_index, structured_conditioning=structured_conditioning)
img = K.sampling.__dict__[f'_{self.schedule}'](
self.model_wrap,
img,

View File

@ -20,13 +20,12 @@ class PLMSSampler(Sampler):
def prepare_to_sample(self, t_enc, **kwargs):
super().prepare_to_sample(t_enc, **kwargs)
edited_conditioning = kwargs.get('edited_conditioning', None)
structured_conditioning = kwargs.get('structured_conditioning', None)
if edited_conditioning is not None:
edit_opcodes = kwargs.get('conditioning_edit_opcodes', None)
self.invokeai_diffuser.setup_cross_attention_control(edited_conditioning, edit_opcodes)
if structured_conditioning is not None and structured_conditioning.wants_cross_attention_control:
self.invokeai_diffuser.setup_cross_attention_control(structured_conditioning)
else:
self.invokeai_diffuser.cleanup_cross_attention_control()
self.invokeai_diffuser.remove_cross_attention_control()
# this is the essential routine
@ -57,10 +56,9 @@ class PLMSSampler(Sampler):
unconditional_conditioning is None
or unconditional_guidance_scale == 1.0
):
# damian0815 does not know if this code path is ever used
# damian0815 would like to know when/if this code path is used
e_t = self.model.apply_model(x, t, c)
else:
e_t = self.invokeai_diffuser.do_diffusion_step(x, t, unconditional_conditioning, c, unconditional_guidance_scale)
if score_corrector is not None:

View File

@ -6,17 +6,22 @@ import torch
class InvokeAIDiffuserComponent:
class Conditioning:
class StructuredConditioning:
def __init__(self, edited_conditioning: torch.Tensor = None, edit_opcodes: list[tuple] = None):
"""
:param edited_conditioning: if doing cross-attention control, the edited conditioning (1 x 77 x 768)
:param edit_opcodes: if doing cross-attention control, opcodes from a SequenceMatcher describing how to map original conditioning tokens to edited conditioning tokens
"""
# TODO migrate conditioning and unconditioning here, too
#self.conditioning = conditioning
#self.unconditioning = unconditioning
self.edited_conditioning = edited_conditioning
self.edit_opcodes = edit_opcodes
@property
def wants_cross_attention_control(self):
return self.edited_conditioning is not None
'''
The aim of this component is to provide a single place for code that can be applied identically to
all InvokeAI diffusion procedures.
@ -34,14 +39,20 @@ class InvokeAIDiffuserComponent:
self.model_forward_callback = model_forward_callback
def setup_cross_attention_control(self, edited_conditioning, edit_opcodes):
self.edited_conditioning = edited_conditioning
CrossAttentionControl.setup_attention_editing(self.model, edited_conditioning, edit_opcodes)
def setup_cross_attention_control(self, conditioning: StructuredConditioning):
self.conditioning = conditioning
CrossAttentionControl.setup_cross_attention_control(self.model, conditioning.edited_conditioning, conditioning.edit_opcodes)
def cleanup_cross_attention_control(self):
self.edited_conditioning = None
CrossAttentionControl.clear_attention_editing(self.model)
def remove_cross_attention_control(self):
self.conditioning = None
CrossAttentionControl.remove_cross_attention_control(self.model)
@property
def edited_conditioning(self):
if self.conditioning is None:
return None
else:
return self.conditioning.edited_conditioning
def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor,
unconditioning: torch.Tensor, conditioning: torch.Tensor,
@ -78,13 +89,13 @@ class InvokeAIDiffuserComponent:
# process x using the original prompt, saving the attention maps
CrossAttentionControl.request_save_attention_maps(self.model)
_ = self.model_forward_callback(x, sigma, cond=conditioning)
_ = self.model_forward_callback(x, sigma, conditioning)
CrossAttentionControl.clear_requests(self.model)
# process x again, using the saved attention maps to control where self.edited_conditioning will be applied
CrossAttentionControl.request_apply_saved_attention_maps(self.model)
conditioned_next_x = self.model_forward_callback(x, sigma, self.edited_conditioning)
CrossAttentionControl.clear_requests(model)
CrossAttentionControl.clear_requests(self.model)
# to scale how much effect conditioning has, calculate the changes it does and then scale that
@ -100,14 +111,16 @@ class CrossAttentionControl:
@classmethod
def clear_attention_editing(cls, model):
def remove_cross_attention_control(cls, model):
cls.remove_attention_function(model)
@classmethod
def setup_attention_editing(cls, model,
substitute_conditioning: torch.Tensor,
edit_opcodes: list):
def setup_cross_attention_control(cls, model,
substitute_conditioning: torch.Tensor,
edit_opcodes: list):
"""
Inject attention parameters and functions into the passed in model to enable cross attention editing.
:param model: The unet model to inject into.
:param substitute_conditioning: The "edited" conditioning vector, [Bx77x768]
:param edit_opcodes: Opcodes from difflib.SequenceMatcher describing how the base