Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
refactored single diffusion path seems to be working for all samplers
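
In summary: instead of threading separate `edited_conditioning` and `conditioning_edit_opcodes` keyword arguments through every generator and sampler, callers now build a single `InvokeAIDiffuserComponent.StructuredConditioning` object and pass it as `structured_conditioning`. A minimal sketch of the new calling convention as it appears in the generators below; the `sampler.decode(...)` call is abbreviated and illustrative, not the full signature:

    from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent

    # conditioning unpacks to four values; the last two feed the new wrapper object
    uc, c, ec, edit_opcodes = conditioning
    structured_conditioning = InvokeAIDiffuserComponent.StructuredConditioning(
        edited_conditioning=ec,    # edited prompt embedding, or None when no edit is requested
        edit_opcodes=edit_opcodes, # SequenceMatcher opcodes mapping original -> edited tokens
    )

    # before this commit (per-sampler kwargs):
    #   sampler.decode(..., edited_conditioning=ec, conditioning_edit_opcodes=edit_opcodes)
    # after this commit (single shared kwarg for all samplers):
    #   sampler.decode(..., structured_conditioning=structured_conditioning)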
@@ -7,6 +7,7 @@ import numpy as np
 from ldm.invoke.devices import choose_autocast
 from ldm.invoke.generator.base import Generator
 from ldm.models.diffusion.ddim import DDIMSampler
+from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent

 class Img2Img(Generator):
 def __init__(self, model, precision):
@@ -33,6 +34,7 @@ class Img2Img(Generator):

 t_enc = int(strength * steps)
 uc, c, ec, edit_opcodes = conditioning
+structured_conditioning = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes)

 def make_image(x_T):
 # encode (scaled latent)
@@ -50,8 +52,7 @@ class Img2Img(Generator):
 unconditional_guidance_scale=cfg_scale,
 unconditional_conditioning=uc,
 init_latent = self.init_latent,
-edited_conditioning = ec,
-conditioning_edit_opcodes = edit_opcodes
+structured_conditioning = structured_conditioning
 # changes how noising is performed in ksampler
 )

@@ -5,6 +5,8 @@ ldm.invoke.generator.txt2img inherits from ldm.invoke.generator
 import torch
 import numpy as np
 from ldm.invoke.generator.base import Generator
+from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
+

 class Txt2Img(Generator):
 def __init__(self, model, precision):
@@ -20,6 +22,7 @@ class Txt2Img(Generator):
 """
 self.perlin = perlin
 uc, c, ec, edit_opcodes = conditioning
+structured_conditioning = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes)

 @torch.no_grad()
 def make_image(x_T):
@@ -43,8 +46,7 @@ class Txt2Img(Generator):
 verbose = False,
 unconditional_guidance_scale = cfg_scale,
 unconditional_conditioning = uc,
-edited_conditioning = ec,
-conditioning_edit_opcodes = edit_opcodes,
+structured_conditioning = structured_conditioning,
 eta = ddim_eta,
 img_callback = step_callback,
 threshold = threshold,
@@ -7,6 +7,7 @@ import numpy as np
 import math
 from ldm.invoke.generator.base import Generator
 from ldm.models.diffusion.ddim import DDIMSampler
+from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent


 class Txt2Img2Img(Generator):
@@ -23,6 +24,7 @@ class Txt2Img2Img(Generator):
 kwargs are 'width' and 'height'
 """
 uc, c, ec, edit_opcodes = conditioning
+structured_conditioning = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes)

 @torch.no_grad()
 def make_image(x_T):
@@ -61,8 +63,7 @@ class Txt2Img2Img(Generator):
 unconditional_conditioning = uc,
 eta = ddim_eta,
 img_callback = step_callback,
-edited_conditioning = ec,
-conditioning_edit_opcodes = edit_opcodes
+structured_conditioning = structured_conditioning
 )

 print(
@@ -96,8 +97,7 @@ class Txt2Img2Img(Generator):
 img_callback = step_callback,
 unconditional_guidance_scale=cfg_scale,
 unconditional_conditioning=uc,
-edited_conditioning = ec,
-conditioning_edit_opcodes = edit_opcodes
+structured_conditioning = structured_conditioning
 )

 if self.free_gpu_mem:
@@ -2,10 +2,6 @@
 from typing import Union

 import torch
-import numpy as np
-from tqdm import tqdm
-from functools import partial
-from ldm.invoke.devices import choose_torch_device
 from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
 from ldm.models.diffusion.sampler import Sampler
 from ldm.modules.diffusionmodules.util import noise_like
@@ -20,13 +16,12 @@ class DDIMSampler(Sampler):
 def prepare_to_sample(self, t_enc, **kwargs):
 super().prepare_to_sample(t_enc, **kwargs)

-edited_conditioning = kwargs.get('edited_conditioning', None)
+structured_conditioning = kwargs.get('structured_conditioning', None)

-if edited_conditioning is not None:
-edit_opcodes = kwargs.get('conditioning_edit_opcodes', None)
-self.invokeai_diffuser.setup_cross_attention_control(edited_conditioning, edit_opcodes)
+if structured_conditioning is not None and structured_conditioning.wants_cross_attention_control:
+self.invokeai_diffuser.setup_cross_attention_control(structured_conditioning)
 else:
-self.invokeai_diffuser.cleanup_cross_attention_control()
+self.invokeai_diffuser.remove_cross_attention_control()


 # This is the central routine
@@ -54,10 +49,9 @@ class DDIMSampler(Sampler):
 unconditional_conditioning is None
 or unconditional_guidance_scale == 1.0
 ):
-# damian0815 does not think this code path is ever used
+# damian0815 would like to know when/if this code path is used
 e_t = self.model.apply_model(x, t, c)
 else:
-
 e_t = self.invokeai_diffuser.do_diffusion_step(x, t, unconditional_conditioning, c, unconditional_guidance_scale)

 if score_corrector is not None:
@@ -34,18 +34,17 @@ class CFGDenoiser(nn.Module):

 def prepare_to_sample(self, t_enc, **kwargs):

-edited_conditioning = kwargs.get('edited_conditioning', None)
+structured_conditioning = kwargs.get('structured_conditioning', None)

-if edited_conditioning is not None:
-conditioning_edit_opcodes = kwargs.get('conditioning_edit_opcodes', None)
-self.invokeai_diffuser.setup_cross_attention_control(edited_conditioning, conditioning_edit_opcodes)
+if structured_conditioning is not None and structured_conditioning.wants_cross_attention_control:
+self.invokeai_diffuser.setup_cross_attention_control(structured_conditioning)
 else:
-self.invokeai_diffuser.cleanup_cross_attention_control()
+self.invokeai_diffuser.remove_cross_attention_control()


 def forward(self, x, sigma, uncond, cond, cond_scale):

-final_next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale)
+next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale)

 # apply threshold
 if self.warmup < self.warmup_max:
@@ -55,7 +54,7 @@ class CFGDenoiser(nn.Module):
 thresh = self.threshold
 if thresh > self.threshold:
 thresh = self.threshold
-return cfg_apply_threshold(final_next_x, thresh)
+return cfg_apply_threshold(next_x, thresh)



@@ -165,8 +164,7 @@ class KSampler(Sampler):
 log_every_t=100,
 unconditional_guidance_scale=1.0,
 unconditional_conditioning=None,
-edited_conditioning=None,
-conditioning_edit_opcodes=None,
+structured_conditioning=None,
 threshold = 0,
 perlin = 0,
 # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
@@ -199,7 +197,7 @@ class KSampler(Sampler):
 x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0]

 model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10))
-model_wrap_cfg.prepare_to_sample(S, edited_conditioning=edited_conditioning, conditioning_edit_opcodes=conditioning_edit_opcodes)
+model_wrap_cfg.prepare_to_sample(S, structured_conditioning=structured_conditioning)
 extra_args = {
 'cond': conditioning,
 'uncond': unconditional_conditioning,
@@ -226,8 +224,7 @@ class KSampler(Sampler):
 index,
 unconditional_guidance_scale=1.0,
 unconditional_conditioning=None,
-edited_conditioning=None,
-conditioning_edit_opcodes=None,
+structured_conditioning=None,
 **kwargs,
 ):
 if self.model_wrap is None:
@@ -253,7 +250,7 @@ class KSampler(Sampler):
 # so the actual formula for indexing into sigmas:
 # sigma_index = (steps-index)
 s_index = t_enc - index - 1
-self.model_wrap.prepare_to_sample(s_index, edited_conditioning=edited_conditioning, conditioning_edit_opcodes=conditioning_edit_opcodes)
+self.model_wrap.prepare_to_sample(s_index, structured_conditioning=structured_conditioning)
 img = K.sampling.__dict__[f'_{self.schedule}'](
 self.model_wrap,
 img,
@@ -20,13 +20,12 @@ class PLMSSampler(Sampler):
 def prepare_to_sample(self, t_enc, **kwargs):
 super().prepare_to_sample(t_enc, **kwargs)

-edited_conditioning = kwargs.get('edited_conditioning', None)
+structured_conditioning = kwargs.get('structured_conditioning', None)

-if edited_conditioning is not None:
-edit_opcodes = kwargs.get('conditioning_edit_opcodes', None)
-self.invokeai_diffuser.setup_cross_attention_control(edited_conditioning, edit_opcodes)
+if structured_conditioning is not None and structured_conditioning.wants_cross_attention_control:
+self.invokeai_diffuser.setup_cross_attention_control(structured_conditioning)
 else:
-self.invokeai_diffuser.cleanup_cross_attention_control()
+self.invokeai_diffuser.remove_cross_attention_control()


 # this is the essential routine
@@ -57,10 +56,9 @@ class PLMSSampler(Sampler):
 unconditional_conditioning is None
 or unconditional_guidance_scale == 1.0
 ):
-# damian0815 does not know if this code path is ever used
+# damian0815 would like to know when/if this code path is used
 e_t = self.model.apply_model(x, t, c)
 else:
-
 e_t = self.invokeai_diffuser.do_diffusion_step(x, t, unconditional_conditioning, c, unconditional_guidance_scale)

 if score_corrector is not None:
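
Every sampler's `prepare_to_sample` now shares the same dispatch on the new object. Condensed from the DDIM/PLMS hunks above (the CFGDenoiser hunk shows the same branch, just without the `super()` call):

    def prepare_to_sample(self, t_enc, **kwargs):
        super().prepare_to_sample(t_enc, **kwargs)

        structured_conditioning = kwargs.get('structured_conditioning', None)

        if structured_conditioning is not None and structured_conditioning.wants_cross_attention_control:
            # hook the unet's attention layers so the edited conditioning is applied
            # where the original tokens' attention maps say it should land
            self.invokeai_diffuser.setup_cross_attention_control(structured_conditioning)
        else:
            self.invokeai_diffuser.remove_cross_attention_control()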
@@ -6,17 +6,22 @@ import torch

 class InvokeAIDiffuserComponent:

-class Conditioning:
+class StructuredConditioning:
 def __init__(self, edited_conditioning: torch.Tensor = None, edit_opcodes: list[tuple] = None):
 """
 :param edited_conditioning: if doing cross-attention control, the edited conditioning (1 x 77 x 768)
 :param edit_opcodes: if doing cross-attention control, opcodes from a SequenceMatcher describing how to map original conditioning tokens to edited conditioning tokens
 """
+# TODO migrate conditioning and unconditioning here, too
 #self.conditioning = conditioning
 #self.unconditioning = unconditioning
 self.edited_conditioning = edited_conditioning
 self.edit_opcodes = edit_opcodes

+@property
+def wants_cross_attention_control(self):
+return self.edited_conditioning is not None
+
 '''
 The aim of this component is to provide a single place for code that can be applied identically to
 all InvokeAI diffusion procedures.
@@ -34,14 +39,20 @@ class InvokeAIDiffuserComponent:
 self.model_forward_callback = model_forward_callback


-def setup_cross_attention_control(self, edited_conditioning, edit_opcodes):
-self.edited_conditioning = edited_conditioning
-CrossAttentionControl.setup_attention_editing(self.model, edited_conditioning, edit_opcodes)
+def setup_cross_attention_control(self, conditioning: StructuredConditioning):
+self.conditioning = conditioning
+CrossAttentionControl.setup_cross_attention_control(self.model, conditioning.edited_conditioning, conditioning.edit_opcodes)

-def cleanup_cross_attention_control(self):
-self.edited_conditioning = None
-CrossAttentionControl.clear_attention_editing(self.model)
+def remove_cross_attention_control(self):
+self.conditioning = None
+CrossAttentionControl.remove_cross_attention_control(self.model)

+@property
+def edited_conditioning(self):
+if self.conditioning is None:
+return None
+else:
+return self.conditioning.edited_conditioning

 def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor,
 unconditioning: torch.Tensor, conditioning: torch.Tensor,
@@ -78,13 +89,13 @@ class InvokeAIDiffuserComponent:

 # process x using the original prompt, saving the attention maps
 CrossAttentionControl.request_save_attention_maps(self.model)
-_ = self.model_forward_callback(x, sigma, cond=conditioning)
+_ = self.model_forward_callback(x, sigma, conditioning)
 CrossAttentionControl.clear_requests(self.model)

 # process x again, using the saved attention maps to control where self.edited_conditioning will be applied
 CrossAttentionControl.request_apply_saved_attention_maps(self.model)
 conditioned_next_x = self.model_forward_callback(x, sigma, self.edited_conditioning)
-CrossAttentionControl.clear_requests(model)
+CrossAttentionControl.clear_requests(self.model)


 # to scale how much effect conditioning has, calculate the changes it does and then scale that
@@ -100,14 +111,16 @@ class CrossAttentionControl:


 @classmethod
-def clear_attention_editing(cls, model):
+def remove_cross_attention_control(cls, model):
 cls.remove_attention_function(model)

 @classmethod
-def setup_attention_editing(cls, model,
+def setup_cross_attention_control(cls, model,
 substitute_conditioning: torch.Tensor,
 edit_opcodes: list):
 """
+Inject attention parameters and functions into the passed in model to enable cross attention editing.
+
 :param model: The unet model to inject into.
 :param substitute_conditioning: The "edited" conditioning vector, [Bx77x768]
 :param edit_opcodes: Opcodes from difflib.SequenceMatcher describing how the base
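
For reference, the two-pass flow that `do_diffusion_step` runs when cross-attention control is active, roughly as it reads in the hunk above (the unconditioned pass and the guidance mixing are omitted from this sketch):

    # first pass: run the ORIGINAL conditioning and record each attention layer's maps
    CrossAttentionControl.request_save_attention_maps(self.model)
    _ = self.model_forward_callback(x, sigma, conditioning)
    CrossAttentionControl.clear_requests(self.model)

    # second pass: replay the saved maps while running the EDITED conditioning,
    # so the edit only takes effect where the original tokens were attended to
    CrossAttentionControl.request_apply_saved_attention_maps(self.model)
    conditioned_next_x = self.model_forward_callback(x, sigma, self.edited_conditioning)
    CrossAttentionControl.clear_requests(self.model)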