mirror of https://github.com/invoke-ai/InvokeAI
wip bringing cross-attention to PLMS and DDIM
@@ -13,7 +13,8 @@ from ldm.modules.diffusionmodules.util import (
     noise_like,
     extract_into_tensor,
 )
-from ldm.models.diffusion.cross_attention import CrossAttentionControl
+from ldm.models.diffusion.cross_attention import CrossAttentionControl, CrossAttentionControllableDiffusionMixin
+
 
 def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
     if threshold <= 0.0:
@@ -29,53 +30,26 @@ def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
     return torch.clamp(result, min=minval, max=maxval)
 
 
-class CFGDenoiser(nn.Module):
-    def __init__(self, model, threshold = 0, warmup = 0, edited_conditioning = None, edit_opcodes = None):
+class CFGDenoiser(nn.Module, CrossAttentionControllableDiffusionMixin):
+    def __init__(self, model, threshold = 0, warmup = 0):
         super().__init__()
         self.inner_model = model
         self.threshold = threshold
         self.warmup_max = warmup
         self.warmup = max(warmup / 10, 1)
 
-        self.edited_conditioning = edited_conditioning
+    def prepare_to_sample(self, t_enc, **kwargs):
 
-        if edited_conditioning is not None:
-            # <start> a cat sitting on a car <end>
-            CrossAttentionControl.setup_attention_editing(self.inner_model, edited_conditioning, edit_opcodes)
-        else:
-            # pass through the attention func but don't act on it
-            CrossAttentionControl.clear_attention_editing(self.inner_model)
+        edited_conditioning = kwargs.get('edited_conditioning', None)
+        conditioning_edit_opcodes = kwargs.get('conditioning_edit_opcodes', None)
+
+        self.setup_cross_attention_control_if_appropriate(self.model, edited_conditioning, conditioning_edit_opcodes)
 
     def forward(self, x, sigma, uncond, cond, cond_scale):
 
-        CrossAttentionControl.clear_requests(self.inner_model)
-
-        if self.edited_conditioning is None:
-            # faster batch path
-            x_twice = torch.cat([x]*2)
-            sigma_twice = torch.cat([sigma]*2)
-            both_conditionings = torch.cat([uncond, cond])
-            unconditioned_next_x, conditioned_next_x = self.inner_model(x_twice, sigma_twice, cond=both_conditionings).chunk(2)
-        else:
-            # slower non-batched path (20% slower on mac MPS)
-            # We are only interested in using attention maps for conditioned_next_x, but batching them with generation of
-            # unconditioned_next_x causes attention maps to *also* be saved for the unconditioned_next_x.
-            # This messes up their application later, due to mismatched shape of dim 0 (16 vs. 8)
-            # (For the batched invocation the `wrangler` function gets attention tensor with shape[0]=16,
-            # representing batched uncond + cond, but then when it comes to applying the saved attention, the
-            # wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.)
-            # todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well.
-            unconditioned_next_x = self.inner_model(x, sigma, cond=uncond)
-
-            # process x using the original prompt, saving the attention maps
-            CrossAttentionControl.request_save_attention_maps(self.inner_model)
-            _ = self.inner_model(x, sigma, cond=cond)
-            CrossAttentionControl.clear_requests(self.inner_model)
-
-            # process x again, using the saved attention maps to control where self.edited_conditioning will be applied
-            CrossAttentionControl.request_apply_saved_attention_maps(self.inner_model)
-            conditioned_next_x = self.inner_model(x, sigma, cond=self.edited_conditioning)
-            CrossAttentionControl.clear_requests(self.inner_model)
+        unconditioned_next_x, conditioned_next_x = self.do_cross_attention_controllable_diffusion_step(x, sigma, uncond, cond, self.inner_model,
+                                                       model_forward_callback=lambda x, sigma, cond: self.inner_model(x, sigma, cond=cond))
 
         if self.warmup < self.warmup_max:
             thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
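The removed forward() body above spells out, inline, the flow that the mixin's do_cross_attention_controllable_diffusion_step call presumably now encapsulates: with no edited conditioning, batch uncond and cond into a single model call; otherwise run separate, un-batched passes, saving attention maps while denoising with the original prompt and re-applying them while denoising with the edited prompt (batching would also record maps for the unconditioned pass and cause the 16-vs-8 shape mismatch noted in the comments). A minimal sketch of that flow, using only the CrossAttentionControl calls visible in the removed code (the helper's name is made up for the sketch); it is an illustration, not the mixin's actual implementation:

    import torch

    from ldm.models.diffusion.cross_attention import CrossAttentionControl


    def cross_attention_controlled_step(model, x, sigma, uncond, cond, edited_cond):
        if edited_cond is None:
            # Fast path: no edit requested, so batch uncond and cond into one forward pass.
            x_twice = torch.cat([x] * 2)
            sigma_twice = torch.cat([sigma] * 2)
            both_conditionings = torch.cat([uncond, cond])
            return model(x_twice, sigma_twice, cond=both_conditionings).chunk(2)

        # Slow path: three un-batched calls so attention maps are recorded only for the
        # conditioned pass.
        unconditioned_next_x = model(x, sigma, cond=uncond)

        # Pass 1: denoise with the original prompt and ask the attention layers to save their maps.
        CrossAttentionControl.request_save_attention_maps(model)
        _ = model(x, sigma, cond=cond)
        CrossAttentionControl.clear_requests(model)

        # Pass 2: denoise with the edited prompt, applying the saved maps so the edit lands
        # where the original tokens attended.
        CrossAttentionControl.request_apply_saved_attention_maps(model)
        conditioned_next_x = model(x, sigma, cond=edited_cond)
        CrossAttentionControl.clear_requests(model)

        return unconditioned_next_x, conditioned_next_x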
@@ -204,7 +178,7 @@ class KSampler(Sampler):
         unconditional_guidance_scale=1.0,
         unconditional_conditioning=None,
         edited_conditioning=None,
-        edit_token_index_map=None,
+        conditioning_edit_opcodes=None,
         threshold = 0,
         perlin = 0,
         # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
@@ -236,21 +210,22 @@ class KSampler(Sampler):
         else:
             x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0]
 
-        model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10),
-                                     edited_conditioning=edited_conditioning, edit_opcodes=edit_token_index_map)
+        model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10))
+        model_wrap_cfg.prepare_to_sample(S, edited_conditioning=edited_conditioning, conditioning_edit_opcodes=conditioning_edit_opcodes)
         extra_args = {
             'cond': conditioning,
             'uncond': unconditional_conditioning,
             'cond_scale': unconditional_guidance_scale,
         }
         print(f'>> Sampling with k_{self.schedule} starting at step {len(self.sigmas)-S-1} of {len(self.sigmas)-1} ({S} new sampling steps)')
-        return (
+        sampling_result = (
             K.sampling.__dict__[f'sample_{self.schedule}'](
                 model_wrap_cfg, x, sigmas, extra_args=extra_args,
                 callback=route_callback
             ),
             None,
         )
+        return sampling_result
 
         # this code will support inpainting if and when ksampler API modified or
         # a workaround is found.
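The construct-then-prepare split shown here moves the cross-attention setup out of CFGDenoiser.__init__ and into the prepare_to_sample() call made just before sampling. Based on the constructor-time code removed in the earlier hunk, that setup presumably reduces to a dispatch like the following sketch (a hedged illustration, not the actual body of the mixin's setup_cross_attention_control_if_appropriate):

    from ldm.models.diffusion.cross_attention import CrossAttentionControl


    def setup_cross_attention_control_if_appropriate(model, edited_conditioning, edit_opcodes):
        if edited_conditioning is not None:
            # An edit was requested: hand the edited embeddings and the token edit opcodes
            # to the attention-editing machinery.
            CrossAttentionControl.setup_attention_editing(model, edited_conditioning, edit_opcodes)
        else:
            # No edit requested: clear any previously installed attention editing so the
            # attention functions act as plain pass-throughs.
            CrossAttentionControl.clear_attention_editing(model)

Passing None for both editing arguments therefore leaves cross-attention control disabled, matching the behaviour of the removed else branch.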
@@ -312,7 +287,7 @@ class KSampler(Sampler):
         else:
             return x
 
-    def prepare_to_sample(self,t_enc):
+    def prepare_to_sample(self,t_enc,**kwargs):
         self.t_enc = t_enc
         self.model_wrap = None
         self.ds = None
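Widening KSampler.prepare_to_sample to accept **kwargs presumably lets calling code pass the same editing keyword arguments to every sampler's prepare_to_sample (the commit is bringing cross-attention control to PLMS and DDIM as well); KSampler itself ignores the extras, since its CFGDenoiser wrapper consumes them. A hypothetical caller-side view:

    # The same kwargs can now be passed to any sampler's prepare_to_sample();
    # KSampler accepts and ignores the extras.
    sampler.prepare_to_sample(
        t_enc,
        edited_conditioning=edited_conditioning,
        conditioning_edit_opcodes=conditioning_edit_opcodes,
    )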
@@ -323,4 +298,3 @@ class KSampler(Sampler):
         Overrides parent method to return the q_sample of the inner model.
         '''
         return self.model.inner_model.q_sample(x0,ts)
-