mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
wip bringing cross-attention to PLMS and DDIM
This commit is contained in:
parent
09f62032ec
commit
54e6a68acb
@ -19,7 +19,7 @@ class Txt2Img(Generator):
|
|||||||
kwargs are 'width' and 'height'
|
kwargs are 'width' and 'height'
|
||||||
"""
|
"""
|
||||||
self.perlin = perlin
|
self.perlin = perlin
|
||||||
uc, c, ec, edit_index_map = conditioning
|
uc, c, ec, edit_opcodes = conditioning
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def make_image(x_T):
|
def make_image(x_T):
|
||||||
@ -44,7 +44,7 @@ class Txt2Img(Generator):
|
|||||||
unconditional_guidance_scale = cfg_scale,
|
unconditional_guidance_scale = cfg_scale,
|
||||||
unconditional_conditioning = uc,
|
unconditional_conditioning = uc,
|
||||||
edited_conditioning = ec,
|
edited_conditioning = ec,
|
||||||
edit_token_index_map = edit_index_map,
|
conditioning_edit_opcodes = edit_opcodes,
|
||||||
eta = ddim_eta,
|
eta = ddim_eta,
|
||||||
img_callback = step_callback,
|
img_callback = step_callback,
|
||||||
threshold = threshold,
|
threshold = threshold,
|
||||||
|
@ -2,6 +2,55 @@ from enum import Enum
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class CrossAttentionControllableDiffusionMixin:
    """Mixin that adds optional cross-attention-controlled diffusion steps.

    The host class calls `setup_cross_attention_control_if_appropriate()`
    before sampling, then `do_cross_attention_controllable_diffusion_step()`
    once per step in place of its usual batched uncond/cond forward pass.
    All attention bookkeeping is delegated to `CrossAttentionControl`.
    """

    def setup_cross_attention_control_if_appropriate(self, model, edited_conditioning, edit_opcodes):
        """Remember the edited conditioning and (de)activate attention editing on `model`.

        :param model: the model object handed to `CrossAttentionControl`.
        :param edited_conditioning: conditioning for the edited prompt, or None
            to disable cross-attention control.
        :param edit_opcodes: passed through unchanged to
            `CrossAttentionControl.setup_attention_editing`.
        """
        self.edited_conditioning = edited_conditioning

        if edited_conditioning is not None:
            # <start> a cat sitting on a car <end>
            CrossAttentionControl.setup_attention_editing(model, edited_conditioning, edit_opcodes)
        else:
            # pass through the attention func but don't act on it
            CrossAttentionControl.clear_attention_editing(model)

    def cleanup_cross_attention_control(self, model):
        """Clear any attention editing previously set up on `model`."""
        CrossAttentionControl.clear_attention_editing(model)

    def do_cross_attention_controllable_diffusion_step(self, x, sigma, unconditioning, conditioning, model, model_forward_callback):
        """Run one diffusion step and return the unconditioned and conditioned results.

        :param x: current latents, forwarded as the first callback argument.
        :param sigma: forwarded as the second callback argument (the DDIM and
            PLMS samplers pass their timestep `t` here).
        :param unconditioning: conditioning for the unconditioned pass.
        :param conditioning: conditioning for the original prompt.
        :param model: the model object handed to `CrossAttentionControl`.
        :param model_forward_callback: callable invoked positionally as
            `model_forward_callback(x, sigma, cond)`.
        :return: tuple `(unconditioned_next_x, conditioned_next_x)`.
        """
        # If setup was never called there is no edited conditioning; treat that
        # the same as "no cross-attention control" instead of raising AttributeError.
        edited_conditioning = getattr(self, 'edited_conditioning', None)

        CrossAttentionControl.clear_requests(model)

        if edited_conditioning is None:
            # faster batched path: uncond goes first, so it is the first chunk
            x_twice = torch.cat([x] * 2)
            sigma_twice = torch.cat([sigma] * 2)
            both_conditionings = torch.cat([unconditioning, conditioning])
            unconditioned_next_x, conditioned_next_x = model_forward_callback(
                x_twice, sigma_twice, both_conditionings
            ).chunk(2)
            return unconditioned_next_x, conditioned_next_x

        # slower non-batched path (20% slower on mac MPS)
        # We are only interested in using attention maps for conditioned_next_x, but batching them with generation of
        # unconditioned_next_x causes attention maps to *also* be saved for the unconditioned_next_x.
        # This messes up their application later, due to mismatched shape of dim 0 (seems to be 16 for batched vs. 8)
        # (For the batched invocation the `wrangler` function gets attention tensor with shape[0]=16,
        # representing batched uncond + cond, but then when it comes to applying the saved attention, the
        # wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.)
        # todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well.
        unconditioned_next_x = model_forward_callback(x, sigma, unconditioning)

        try:
            # process x using the original prompt, saving the attention maps
            # (the output itself is discarded; only the saved maps matter)
            CrossAttentionControl.request_save_attention_maps(model)
            model_forward_callback(x, sigma, conditioning)
            CrossAttentionControl.clear_requests(model)

            # process x again, using the saved attention maps to control where the edited conditioning will be applied
            CrossAttentionControl.request_apply_saved_attention_maps(model)
            conditioned_next_x = model_forward_callback(x, sigma, edited_conditioning)
        finally:
            # never leave a save/apply request pending on the model, even if a forward pass raised
            CrossAttentionControl.clear_requests(model)

        return unconditioned_next_x, conditioned_next_x
|
||||||
|
|
||||||
# adapted from bloc97's CrossAttentionControl colab
|
# adapted from bloc97's CrossAttentionControl colab
|
||||||
# https://github.com/bloc97/CrossAttentionControl
|
# https://github.com/bloc97/CrossAttentionControl
|
||||||
|
|
||||||
@ -27,7 +76,8 @@ class CrossAttentionControl:
|
|||||||
# adapted from init_attention_edit
|
# adapted from init_attention_edit
|
||||||
device = substitute_conditioning.device
|
device = substitute_conditioning.device
|
||||||
|
|
||||||
max_length = model.inner_model.cond_stage_model.max_length
|
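# TODO: 77 is hardcoded here; the previous code read model.inner_model.cond_stage_model.max_length. Should this be looked up from the model instead?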
|
||||||
|
max_length = 77
|
||||||
# mask=1 means use base prompt attention, mask=0 means use edited prompt attention
|
# mask=1 means use base prompt attention, mask=0 means use edited prompt attention
|
||||||
mask = torch.zeros(max_length)
|
mask = torch.zeros(max_length)
|
||||||
indices_target = torch.arange(max_length, dtype=torch.long)
|
indices_target = torch.arange(max_length, dtype=torch.long)
|
||||||
|
@ -5,13 +5,23 @@ import numpy as np
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from ldm.invoke.devices import choose_torch_device
|
from ldm.invoke.devices import choose_torch_device
|
||||||
|
from ldm.models.diffusion.cross_attention import CrossAttentionControllableDiffusionMixin
|
||||||
from ldm.models.diffusion.sampler import Sampler
|
from ldm.models.diffusion.sampler import Sampler
|
||||||
from ldm.modules.diffusionmodules.util import noise_like
|
from ldm.modules.diffusionmodules.util import noise_like
|
||||||
|
|
||||||
class DDIMSampler(Sampler):
|
class DDIMSampler(Sampler, CrossAttentionControllableDiffusionMixin):
|
||||||
def __init__(self, model, schedule='linear', device=None, **kwargs):
|
def __init__(self, model, schedule='linear', device=None, **kwargs):
|
||||||
super().__init__(model,schedule,model.num_timesteps,device)
|
super().__init__(model,schedule,model.num_timesteps,device)
|
||||||
|
|
||||||
|
def prepare_to_sample(self, t_enc, **kwargs):
    """Run the parent's preparation, then configure cross-attention control.

    Reads the optional `edited_conditioning` and `conditioning_edit_opcodes`
    keyword arguments (each defaulting to None) and hands them, together
    with `self.model`, to `setup_cross_attention_control_if_appropriate`.
    """
    super().prepare_to_sample(t_enc, **kwargs)

    self.setup_cross_attention_control_if_appropriate(
        self.model,
        kwargs.get('edited_conditioning'),
        kwargs.get('conditioning_edit_opcodes'),
    )
|
||||||
|
|
||||||
|
|
||||||
# This is the central routine
|
# This is the central routine
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def p_sample(
|
def p_sample(
|
||||||
@ -37,12 +47,13 @@ class DDIMSampler(Sampler):
|
|||||||
unconditional_conditioning is None
|
unconditional_conditioning is None
|
||||||
or unconditional_guidance_scale == 1.0
|
or unconditional_guidance_scale == 1.0
|
||||||
):
|
):
|
||||||
|
# damian0815 does not think this code path is ever used
|
||||||
e_t = self.model.apply_model(x, t, c)
|
e_t = self.model.apply_model(x, t, c)
|
||||||
else:
|
else:
|
||||||
x_in = torch.cat([x] * 2)
|
|
||||||
t_in = torch.cat([t] * 2)
|
e_t_uncond, e_t = self.do_cross_attention_controllable_diffusion_step(x, t, unconditional_conditioning, c, self.model,
|
||||||
c_in = torch.cat([unconditional_conditioning, c])
|
model_forward_callback=lambda x, sigma, cond: self.model.apply_model(x, sigma, cond))
|
||||||
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
|
||||||
e_t = e_t_uncond + unconditional_guidance_scale * (
|
e_t = e_t_uncond + unconditional_guidance_scale * (
|
||||||
e_t - e_t_uncond
|
e_t - e_t_uncond
|
||||||
)
|
)
|
||||||
|
@ -13,7 +13,8 @@ from ldm.modules.diffusionmodules.util import (
|
|||||||
noise_like,
|
noise_like,
|
||||||
extract_into_tensor,
|
extract_into_tensor,
|
||||||
)
|
)
|
||||||
from ldm.models.diffusion.cross_attention import CrossAttentionControl
|
from ldm.models.diffusion.cross_attention import CrossAttentionControl, CrossAttentionControllableDiffusionMixin
|
||||||
|
|
||||||
|
|
||||||
def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
|
def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
|
||||||
if threshold <= 0.0:
|
if threshold <= 0.0:
|
||||||
@ -29,53 +30,26 @@ def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
|
|||||||
return torch.clamp(result, min=minval, max=maxval)
|
return torch.clamp(result, min=minval, max=maxval)
|
||||||
|
|
||||||
|
|
||||||
class CFGDenoiser(nn.Module):
|
class CFGDenoiser(nn.Module, CrossAttentionControllableDiffusionMixin):
|
||||||
def __init__(self, model, threshold = 0, warmup = 0, edited_conditioning = None, edit_opcodes = None):
|
def __init__(self, model, threshold = 0, warmup = 0):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.inner_model = model
|
self.inner_model = model
|
||||||
self.threshold = threshold
|
self.threshold = threshold
|
||||||
self.warmup_max = warmup
|
self.warmup_max = warmup
|
||||||
self.warmup = max(warmup / 10, 1)
|
self.warmup = max(warmup / 10, 1)
|
||||||
|
|
||||||
self.edited_conditioning = edited_conditioning
|
def prepare_to_sample(self, t_enc, **kwargs):
    """Configure cross-attention control on the wrapped model before sampling.

    :param t_enc: accepted for signature compatibility with the samplers'
        `prepare_to_sample`; not used here.
    :param kwargs: may carry `edited_conditioning` and
        `conditioning_edit_opcodes`; both default to None when absent.
    """
    edited_conditioning = kwargs.get('edited_conditioning', None)
    conditioning_edit_opcodes = kwargs.get('conditioning_edit_opcodes', None)

    # CFGDenoiser keeps its wrapped model in `inner_model` (assigned in
    # __init__ and used by forward); it has no `model` attribute, so the
    # previous `self.model` lookup raised AttributeError.
    self.setup_cross_attention_control_if_appropriate(self.inner_model, edited_conditioning, conditioning_edit_opcodes)
|
||||||
|
|
||||||
if edited_conditioning is not None:
|
|
||||||
# <start> a cat sitting on a car <end>
|
|
||||||
CrossAttentionControl.setup_attention_editing(self.inner_model, edited_conditioning, edit_opcodes)
|
|
||||||
else:
|
|
||||||
# pass through the attention func but don't act on it
|
|
||||||
CrossAttentionControl.clear_attention_editing(self.inner_model)
|
|
||||||
|
|
||||||
def forward(self, x, sigma, uncond, cond, cond_scale):
|
def forward(self, x, sigma, uncond, cond, cond_scale):
|
||||||
|
|
||||||
CrossAttentionControl.clear_requests(self.inner_model)
|
unconditioned_next_x, conditioned_next_x = self.do_cross_attention_controllable_diffusion_step(x, sigma, uncond, cond, self.inner_model,
|
||||||
|
model_forward_callback=lambda x, sigma, cond: self.inner_model(x, sigma, cond=cond))
|
||||||
if self.edited_conditioning is None:
|
|
||||||
# faster batch path
|
|
||||||
x_twice = torch.cat([x]*2)
|
|
||||||
sigma_twice = torch.cat([sigma]*2)
|
|
||||||
both_conditionings = torch.cat([uncond, cond])
|
|
||||||
unconditioned_next_x, conditioned_next_x = self.inner_model(x_twice, sigma_twice, cond=both_conditionings).chunk(2)
|
|
||||||
else:
|
|
||||||
# slower non-batched path (20% slower on mac MPS)
|
|
||||||
# We are only interested in using attention maps for conditioned_next_x, but batching them with generation of
|
|
||||||
# unconditioned_next_x causes attention maps to *also* be saved for the unconditioned_next_x.
|
|
||||||
# This messes app their application later, due to mismatched shape of dim 0 (16 vs. 8)
|
|
||||||
# (For the batched invocation the `wrangler` function gets attention tensor with shape[0]=16,
|
|
||||||
# representing batched uncond + cond, but then when it comes to applying the saved attention, the
|
|
||||||
# wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.)
|
|
||||||
# todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well.
|
|
||||||
unconditioned_next_x = self.inner_model(x, sigma, cond=uncond)
|
|
||||||
|
|
||||||
# process x using the original prompt, saving the attention maps
|
|
||||||
CrossAttentionControl.request_save_attention_maps(self.inner_model)
|
|
||||||
_ = self.inner_model(x, sigma, cond=cond)
|
|
||||||
CrossAttentionControl.clear_requests(self.inner_model)
|
|
||||||
|
|
||||||
# process x again, using the saved attention maps to control where self.edited_conditioning will be applied
|
|
||||||
CrossAttentionControl.request_apply_saved_attention_maps(self.inner_model)
|
|
||||||
conditioned_next_x = self.inner_model(x, sigma, cond=self.edited_conditioning)
|
|
||||||
CrossAttentionControl.clear_requests(self.inner_model)
|
|
||||||
|
|
||||||
if self.warmup < self.warmup_max:
|
if self.warmup < self.warmup_max:
|
||||||
thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
|
thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
|
||||||
@ -204,7 +178,7 @@ class KSampler(Sampler):
|
|||||||
unconditional_guidance_scale=1.0,
|
unconditional_guidance_scale=1.0,
|
||||||
unconditional_conditioning=None,
|
unconditional_conditioning=None,
|
||||||
edited_conditioning=None,
|
edited_conditioning=None,
|
||||||
edit_token_index_map=None,
|
conditioning_edit_opcodes=None,
|
||||||
threshold = 0,
|
threshold = 0,
|
||||||
perlin = 0,
|
perlin = 0,
|
||||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||||
@ -236,21 +210,22 @@ class KSampler(Sampler):
|
|||||||
else:
|
else:
|
||||||
x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0]
|
x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0]
|
||||||
|
|
||||||
model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10),
|
model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10))
|
||||||
edited_conditioning=edited_conditioning, edit_opcodes=edit_token_index_map)
|
model_wrap_cfg.prepare_to_sample(S, edited_conditioning=edited_conditioning, conditioning_edit_opcodes=conditioning_edit_opcodes)
|
||||||
extra_args = {
|
extra_args = {
|
||||||
'cond': conditioning,
|
'cond': conditioning,
|
||||||
'uncond': unconditional_conditioning,
|
'uncond': unconditional_conditioning,
|
||||||
'cond_scale': unconditional_guidance_scale,
|
'cond_scale': unconditional_guidance_scale,
|
||||||
}
|
}
|
||||||
print(f'>> Sampling with k_{self.schedule} starting at step {len(self.sigmas)-S-1} of {len(self.sigmas)-1} ({S} new sampling steps)')
|
print(f'>> Sampling with k_{self.schedule} starting at step {len(self.sigmas)-S-1} of {len(self.sigmas)-1} ({S} new sampling steps)')
|
||||||
return (
|
sampling_result = (
|
||||||
K.sampling.__dict__[f'sample_{self.schedule}'](
|
K.sampling.__dict__[f'sample_{self.schedule}'](
|
||||||
model_wrap_cfg, x, sigmas, extra_args=extra_args,
|
model_wrap_cfg, x, sigmas, extra_args=extra_args,
|
||||||
callback=route_callback
|
callback=route_callback
|
||||||
),
|
),
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
|
return sampling_result
|
||||||
|
|
||||||
# this code will support inpainting if and when ksampler API modified or
|
# this code will support inpainting if and when ksampler API modified or
|
||||||
# a workaround is found.
|
# a workaround is found.
|
||||||
@ -312,7 +287,7 @@ class KSampler(Sampler):
|
|||||||
else:
|
else:
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def prepare_to_sample(self,t_enc):
|
def prepare_to_sample(self,t_enc,**kwargs):
|
||||||
self.t_enc = t_enc
|
self.t_enc = t_enc
|
||||||
self.model_wrap = None
|
self.model_wrap = None
|
||||||
self.ds = None
|
self.ds = None
|
||||||
@ -323,4 +298,3 @@ class KSampler(Sampler):
|
|||||||
Overrides parent method to return the q_sample of the inner model.
|
Overrides parent method to return the q_sample of the inner model.
|
||||||
'''
|
'''
|
||||||
return self.model.inner_model.q_sample(x0,ts)
|
return self.model.inner_model.q_sample(x0,ts)
|
||||||
|
|
||||||
|
@ -5,14 +5,24 @@ import numpy as np
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from ldm.invoke.devices import choose_torch_device
|
from ldm.invoke.devices import choose_torch_device
|
||||||
|
from ldm.models.diffusion.cross_attention import CrossAttentionControllableDiffusionMixin
|
||||||
from ldm.models.diffusion.sampler import Sampler
|
from ldm.models.diffusion.sampler import Sampler
|
||||||
from ldm.modules.diffusionmodules.util import noise_like
|
from ldm.modules.diffusionmodules.util import noise_like
|
||||||
|
|
||||||
|
|
||||||
class PLMSSampler(Sampler):
|
class PLMSSampler(Sampler, CrossAttentionControllableDiffusionMixin):
|
||||||
def __init__(self, model, schedule='linear', device=None, **kwargs):
|
def __init__(self, model, schedule='linear', device=None, **kwargs):
|
||||||
super().__init__(model,schedule,model.num_timesteps, device)
|
super().__init__(model,schedule,model.num_timesteps, device)
|
||||||
|
|
||||||
|
def prepare_to_sample(self, t_enc, **kwargs):
    """Run the parent's preparation, then configure cross-attention control.

    Reads the optional `edited_conditioning` and `conditioning_edit_opcodes`
    keyword arguments (each defaulting to None) and hands them, together
    with `self.model`, to `setup_cross_attention_control_if_appropriate`.
    """
    super().prepare_to_sample(t_enc, **kwargs)

    self.setup_cross_attention_control_if_appropriate(
        self.model,
        kwargs.get('edited_conditioning'),
        kwargs.get('conditioning_edit_opcodes'),
    )
|
||||||
|
|
||||||
|
|
||||||
# this is the essential routine
|
# this is the essential routine
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def p_sample(
|
def p_sample(
|
||||||
@ -41,14 +51,18 @@ class PLMSSampler(Sampler):
|
|||||||
unconditional_conditioning is None
|
unconditional_conditioning is None
|
||||||
or unconditional_guidance_scale == 1.0
|
or unconditional_guidance_scale == 1.0
|
||||||
):
|
):
|
||||||
|
# damian0815 does not think this code path is ever used
|
||||||
e_t = self.model.apply_model(x, t, c)
|
e_t = self.model.apply_model(x, t, c)
|
||||||
else:
|
else:
|
||||||
x_in = torch.cat([x] * 2)
|
#x_in = torch.cat([x] * 2)
|
||||||
t_in = torch.cat([t] * 2)
|
#t_in = torch.cat([t] * 2)
|
||||||
c_in = torch.cat([unconditional_conditioning, c])
|
#c_in = torch.cat([unconditional_conditioning, c])
|
||||||
e_t_uncond, e_t = self.model.apply_model(
|
#e_t_uncond, e_t = self.model.apply_model(
|
||||||
x_in, t_in, c_in
|
# x_in, t_in, c_in
|
||||||
).chunk(2)
|
#).chunk(2)
|
||||||
|
e_t_uncond, e_t = self.do_cross_attention_controllable_diffusion_step(x, t, unconditional_conditioning, c, self.model,
|
||||||
|
model_forward_callback=lambda x, sigma, cond: self.model.apply_model(x, sigma, cond))
|
||||||
|
|
||||||
e_t = e_t_uncond + unconditional_guidance_scale * (
|
e_t = e_t_uncond + unconditional_guidance_scale * (
|
||||||
e_t - e_t_uncond
|
e_t - e_t_uncond
|
||||||
)
|
)
|
||||||
|
@ -192,6 +192,7 @@ class Sampler(object):
|
|||||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||||
unconditional_conditioning=unconditional_conditioning,
|
unconditional_conditioning=unconditional_conditioning,
|
||||||
steps=S,
|
steps=S,
|
||||||
|
**kwargs
|
||||||
)
|
)
|
||||||
return samples, intermediates
|
return samples, intermediates
|
||||||
|
|
||||||
@ -216,6 +217,7 @@ class Sampler(object):
|
|||||||
unconditional_guidance_scale=1.0,
|
unconditional_guidance_scale=1.0,
|
||||||
unconditional_conditioning=None,
|
unconditional_conditioning=None,
|
||||||
steps=None,
|
steps=None,
|
||||||
|
**kwargs
|
||||||
):
|
):
|
||||||
b = shape[0]
|
b = shape[0]
|
||||||
time_range = (
|
time_range = (
|
||||||
@ -233,7 +235,7 @@ class Sampler(object):
|
|||||||
dynamic_ncols=True,
|
dynamic_ncols=True,
|
||||||
)
|
)
|
||||||
old_eps = []
|
old_eps = []
|
||||||
self.prepare_to_sample(t_enc=total_steps)
|
self.prepare_to_sample(t_enc=total_steps,**kwargs)
|
||||||
img = self.get_initial_image(x_T,shape,total_steps)
|
img = self.get_initial_image(x_T,shape,total_steps)
|
||||||
|
|
||||||
# probably don't need this at all
|
# probably don't need this at all
|
||||||
@ -323,7 +325,7 @@ class Sampler(object):
|
|||||||
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
|
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
|
||||||
x_dec = x_latent
|
x_dec = x_latent
|
||||||
x0 = init_latent
|
x0 = init_latent
|
||||||
self.prepare_to_sample(t_enc=total_steps)
|
self.prepare_to_sample(t_enc=total_steps,**kwargs)
|
||||||
|
|
||||||
for i, step in enumerate(iterator):
|
for i, step in enumerate(iterator):
|
||||||
index = total_steps - i - 1
|
index = total_steps - i - 1
|
||||||
@ -414,5 +416,3 @@ class Sampler(object):
|
|||||||
'''
|
'''
|
||||||
return self.model.q_sample(x0,ts)
|
return self.model.q_sample(x0,ts)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user