mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
refactored single diffusion path; seems to be working for all samplers
This commit is contained in:
parent
147d39cb7c
commit
1ffd4a9e06
@ -7,6 +7,7 @@ import numpy as np
|
||||
from ldm.invoke.devices import choose_autocast
|
||||
from ldm.invoke.generator.base import Generator
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
|
||||
class Img2Img(Generator):
|
||||
def __init__(self, model, precision):
|
||||
@ -33,6 +34,7 @@ class Img2Img(Generator):
|
||||
|
||||
t_enc = int(strength * steps)
|
||||
uc, c, ec, edit_opcodes = conditioning
|
||||
structured_conditioning = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes)
|
||||
|
||||
def make_image(x_T):
|
||||
# encode (scaled latent)
|
||||
@ -50,8 +52,7 @@ class Img2Img(Generator):
|
||||
unconditional_guidance_scale=cfg_scale,
|
||||
unconditional_conditioning=uc,
|
||||
init_latent = self.init_latent,
|
||||
edited_conditioning = ec,
|
||||
conditioning_edit_opcodes = edit_opcodes
|
||||
structured_conditioning = structured_conditioning
|
||||
# changes how noising is performed in ksampler
|
||||
)
|
||||
|
||||
|
@ -5,6 +5,8 @@ ldm.invoke.generator.txt2img inherits from ldm.invoke.generator
|
||||
import torch
|
||||
import numpy as np
|
||||
from ldm.invoke.generator.base import Generator
|
||||
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
|
||||
|
||||
class Txt2Img(Generator):
|
||||
def __init__(self, model, precision):
|
||||
@ -20,6 +22,7 @@ class Txt2Img(Generator):
|
||||
"""
|
||||
self.perlin = perlin
|
||||
uc, c, ec, edit_opcodes = conditioning
|
||||
structured_conditioning = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes)
|
||||
|
||||
@torch.no_grad()
|
||||
def make_image(x_T):
|
||||
@ -43,8 +46,7 @@ class Txt2Img(Generator):
|
||||
verbose = False,
|
||||
unconditional_guidance_scale = cfg_scale,
|
||||
unconditional_conditioning = uc,
|
||||
edited_conditioning = ec,
|
||||
conditioning_edit_opcodes = edit_opcodes,
|
||||
structured_conditioning = structured_conditioning,
|
||||
eta = ddim_eta,
|
||||
img_callback = step_callback,
|
||||
threshold = threshold,
|
||||
|
@ -7,6 +7,7 @@ import numpy as np
|
||||
import math
|
||||
from ldm.invoke.generator.base import Generator
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
|
||||
|
||||
class Txt2Img2Img(Generator):
|
||||
@ -23,6 +24,7 @@ class Txt2Img2Img(Generator):
|
||||
kwargs are 'width' and 'height'
|
||||
"""
|
||||
uc, c, ec, edit_opcodes = conditioning
|
||||
structured_conditioning = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes)
|
||||
|
||||
@torch.no_grad()
|
||||
def make_image(x_T):
|
||||
@ -61,8 +63,7 @@ class Txt2Img2Img(Generator):
|
||||
unconditional_conditioning = uc,
|
||||
eta = ddim_eta,
|
||||
img_callback = step_callback,
|
||||
edited_conditioning = ec,
|
||||
conditioning_edit_opcodes = edit_opcodes
|
||||
structured_conditioning = structured_conditioning
|
||||
)
|
||||
|
||||
print(
|
||||
@ -96,8 +97,7 @@ class Txt2Img2Img(Generator):
|
||||
img_callback = step_callback,
|
||||
unconditional_guidance_scale=cfg_scale,
|
||||
unconditional_conditioning=uc,
|
||||
edited_conditioning = ec,
|
||||
conditioning_edit_opcodes = edit_opcodes
|
||||
structured_conditioning = structured_conditioning
|
||||
)
|
||||
|
||||
if self.free_gpu_mem:
|
||||
|
@ -2,10 +2,6 @@
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from functools import partial
|
||||
from ldm.invoke.devices import choose_torch_device
|
||||
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
from ldm.models.diffusion.sampler import Sampler
|
||||
from ldm.modules.diffusionmodules.util import noise_like
|
||||
@ -20,13 +16,12 @@ class DDIMSampler(Sampler):
|
||||
def prepare_to_sample(self, t_enc, **kwargs):
|
||||
super().prepare_to_sample(t_enc, **kwargs)
|
||||
|
||||
edited_conditioning = kwargs.get('edited_conditioning', None)
|
||||
structured_conditioning = kwargs.get('structured_conditioning', None)
|
||||
|
||||
if edited_conditioning is not None:
|
||||
edit_opcodes = kwargs.get('conditioning_edit_opcodes', None)
|
||||
self.invokeai_diffuser.setup_cross_attention_control(edited_conditioning, edit_opcodes)
|
||||
if structured_conditioning is not None and structured_conditioning.wants_cross_attention_control:
|
||||
self.invokeai_diffuser.setup_cross_attention_control(structured_conditioning)
|
||||
else:
|
||||
self.invokeai_diffuser.cleanup_cross_attention_control()
|
||||
self.invokeai_diffuser.remove_cross_attention_control()
|
||||
|
||||
|
||||
# This is the central routine
|
||||
@ -54,10 +49,9 @@ class DDIMSampler(Sampler):
|
||||
unconditional_conditioning is None
|
||||
or unconditional_guidance_scale == 1.0
|
||||
):
|
||||
# damian0815 does not think this code path is ever used
|
||||
# damian0815 would like to know when/if this code path is used
|
||||
e_t = self.model.apply_model(x, t, c)
|
||||
else:
|
||||
|
||||
e_t = self.invokeai_diffuser.do_diffusion_step(x, t, unconditional_conditioning, c, unconditional_guidance_scale)
|
||||
|
||||
if score_corrector is not None:
|
||||
|
@ -34,18 +34,17 @@ class CFGDenoiser(nn.Module):
|
||||
|
||||
def prepare_to_sample(self, t_enc, **kwargs):
|
||||
|
||||
edited_conditioning = kwargs.get('edited_conditioning', None)
|
||||
structured_conditioning = kwargs.get('structured_conditioning', None)
|
||||
|
||||
if edited_conditioning is not None:
|
||||
conditioning_edit_opcodes = kwargs.get('conditioning_edit_opcodes', None)
|
||||
self.invokeai_diffuser.setup_cross_attention_control(edited_conditioning, conditioning_edit_opcodes)
|
||||
if structured_conditioning is not None and structured_conditioning.wants_cross_attention_control:
|
||||
self.invokeai_diffuser.setup_cross_attention_control(structured_conditioning)
|
||||
else:
|
||||
self.invokeai_diffuser.cleanup_cross_attention_control()
|
||||
self.invokeai_diffuser.remove_cross_attention_control()
|
||||
|
||||
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale):
|
||||
|
||||
final_next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale)
|
||||
next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale)
|
||||
|
||||
# apply threshold
|
||||
if self.warmup < self.warmup_max:
|
||||
@ -55,7 +54,7 @@ class CFGDenoiser(nn.Module):
|
||||
thresh = self.threshold
|
||||
if thresh > self.threshold:
|
||||
thresh = self.threshold
|
||||
return cfg_apply_threshold(final_next_x, thresh)
|
||||
return cfg_apply_threshold(next_x, thresh)
|
||||
|
||||
|
||||
|
||||
@ -165,8 +164,7 @@ class KSampler(Sampler):
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
edited_conditioning=None,
|
||||
conditioning_edit_opcodes=None,
|
||||
structured_conditioning=None,
|
||||
threshold = 0,
|
||||
perlin = 0,
|
||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
@ -199,7 +197,7 @@ class KSampler(Sampler):
|
||||
x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0]
|
||||
|
||||
model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10))
|
||||
model_wrap_cfg.prepare_to_sample(S, edited_conditioning=edited_conditioning, conditioning_edit_opcodes=conditioning_edit_opcodes)
|
||||
model_wrap_cfg.prepare_to_sample(S, structured_conditioning=structured_conditioning)
|
||||
extra_args = {
|
||||
'cond': conditioning,
|
||||
'uncond': unconditional_conditioning,
|
||||
@ -226,8 +224,7 @@ class KSampler(Sampler):
|
||||
index,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
edited_conditioning=None,
|
||||
conditioning_edit_opcodes=None,
|
||||
structured_conditioning=None,
|
||||
**kwargs,
|
||||
):
|
||||
if self.model_wrap is None:
|
||||
@ -253,7 +250,7 @@ class KSampler(Sampler):
|
||||
# so the actual formula for indexing into sigmas:
|
||||
# sigma_index = (steps-index)
|
||||
s_index = t_enc - index - 1
|
||||
self.model_wrap.prepare_to_sample(s_index, edited_conditioning=edited_conditioning, conditioning_edit_opcodes=conditioning_edit_opcodes)
|
||||
self.model_wrap.prepare_to_sample(s_index, structured_conditioning=structured_conditioning)
|
||||
img = K.sampling.__dict__[f'_{self.schedule}'](
|
||||
self.model_wrap,
|
||||
img,
|
||||
|
@ -20,13 +20,12 @@ class PLMSSampler(Sampler):
|
||||
def prepare_to_sample(self, t_enc, **kwargs):
|
||||
super().prepare_to_sample(t_enc, **kwargs)
|
||||
|
||||
edited_conditioning = kwargs.get('edited_conditioning', None)
|
||||
structured_conditioning = kwargs.get('structured_conditioning', None)
|
||||
|
||||
if edited_conditioning is not None:
|
||||
edit_opcodes = kwargs.get('conditioning_edit_opcodes', None)
|
||||
self.invokeai_diffuser.setup_cross_attention_control(edited_conditioning, edit_opcodes)
|
||||
if structured_conditioning is not None and structured_conditioning.wants_cross_attention_control:
|
||||
self.invokeai_diffuser.setup_cross_attention_control(structured_conditioning)
|
||||
else:
|
||||
self.invokeai_diffuser.cleanup_cross_attention_control()
|
||||
self.invokeai_diffuser.remove_cross_attention_control()
|
||||
|
||||
|
||||
# this is the essential routine
|
||||
@ -57,10 +56,9 @@ class PLMSSampler(Sampler):
|
||||
unconditional_conditioning is None
|
||||
or unconditional_guidance_scale == 1.0
|
||||
):
|
||||
# damian0815 does not know if this code path is ever used
|
||||
# damian0815 would like to know when/if this code path is used
|
||||
e_t = self.model.apply_model(x, t, c)
|
||||
else:
|
||||
|
||||
e_t = self.invokeai_diffuser.do_diffusion_step(x, t, unconditional_conditioning, c, unconditional_guidance_scale)
|
||||
|
||||
if score_corrector is not None:
|
||||
|
@ -6,17 +6,22 @@ import torch
|
||||
|
||||
class InvokeAIDiffuserComponent:
|
||||
|
||||
class Conditioning:
|
||||
class StructuredConditioning:
|
||||
def __init__(self, edited_conditioning: torch.Tensor = None, edit_opcodes: list[tuple] = None):
|
||||
"""
|
||||
:param edited_conditioning: if doing cross-attention control, the edited conditioning (1 x 77 x 768)
|
||||
:param edit_opcodes: if doing cross-attention control, opcodes from a SequenceMatcher describing how to map original conditioning tokens to edited conditioning tokens
|
||||
"""
|
||||
# TODO migrate conditioning and unconditioning here, too
|
||||
#self.conditioning = conditioning
|
||||
#self.unconditioning = unconditioning
|
||||
self.edited_conditioning = edited_conditioning
|
||||
self.edit_opcodes = edit_opcodes
|
||||
|
||||
@property
|
||||
def wants_cross_attention_control(self):
|
||||
return self.edited_conditioning is not None
|
||||
|
||||
'''
|
||||
The aim of this component is to provide a single place for code that can be applied identically to
|
||||
all InvokeAI diffusion procedures.
|
||||
@ -34,14 +39,20 @@ class InvokeAIDiffuserComponent:
|
||||
self.model_forward_callback = model_forward_callback
|
||||
|
||||
|
||||
def setup_cross_attention_control(self, edited_conditioning, edit_opcodes):
|
||||
self.edited_conditioning = edited_conditioning
|
||||
CrossAttentionControl.setup_attention_editing(self.model, edited_conditioning, edit_opcodes)
|
||||
def setup_cross_attention_control(self, conditioning: StructuredConditioning):
|
||||
self.conditioning = conditioning
|
||||
CrossAttentionControl.setup_cross_attention_control(self.model, conditioning.edited_conditioning, conditioning.edit_opcodes)
|
||||
|
||||
def cleanup_cross_attention_control(self):
|
||||
self.edited_conditioning = None
|
||||
CrossAttentionControl.clear_attention_editing(self.model)
|
||||
def remove_cross_attention_control(self):
|
||||
self.conditioning = None
|
||||
CrossAttentionControl.remove_cross_attention_control(self.model)
|
||||
|
||||
@property
|
||||
def edited_conditioning(self):
|
||||
if self.conditioning is None:
|
||||
return None
|
||||
else:
|
||||
return self.conditioning.edited_conditioning
|
||||
|
||||
def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor,
|
||||
unconditioning: torch.Tensor, conditioning: torch.Tensor,
|
||||
@ -78,13 +89,13 @@ class InvokeAIDiffuserComponent:
|
||||
|
||||
# process x using the original prompt, saving the attention maps
|
||||
CrossAttentionControl.request_save_attention_maps(self.model)
|
||||
_ = self.model_forward_callback(x, sigma, cond=conditioning)
|
||||
_ = self.model_forward_callback(x, sigma, conditioning)
|
||||
CrossAttentionControl.clear_requests(self.model)
|
||||
|
||||
# process x again, using the saved attention maps to control where self.edited_conditioning will be applied
|
||||
CrossAttentionControl.request_apply_saved_attention_maps(self.model)
|
||||
conditioned_next_x = self.model_forward_callback(x, sigma, self.edited_conditioning)
|
||||
CrossAttentionControl.clear_requests(model)
|
||||
CrossAttentionControl.clear_requests(self.model)
|
||||
|
||||
|
||||
# to scale how much effect conditioning has, calculate the changes it does and then scale that
|
||||
@ -100,14 +111,16 @@ class CrossAttentionControl:
|
||||
|
||||
|
||||
@classmethod
|
||||
def clear_attention_editing(cls, model):
|
||||
def remove_cross_attention_control(cls, model):
|
||||
cls.remove_attention_function(model)
|
||||
|
||||
@classmethod
|
||||
def setup_attention_editing(cls, model,
|
||||
substitute_conditioning: torch.Tensor,
|
||||
edit_opcodes: list):
|
||||
def setup_cross_attention_control(cls, model,
|
||||
substitute_conditioning: torch.Tensor,
|
||||
edit_opcodes: list):
|
||||
"""
|
||||
Inject attention parameters and functions into the passed in model to enable cross attention editing.
|
||||
|
||||
:param model: The unet model to inject into.
|
||||
:param substitute_conditioning: The "edited" conditioning vector, [Bx77x768]
|
||||
:param edit_opcodes: Opcodes from difflib.SequenceMatcher describing how the base
|
||||
|
Loading…
Reference in New Issue
Block a user