wip refactoring shared InvokeAI diffuser mixin to component

This commit is contained in:
Damian at mba 2022-10-19 18:19:55 +02:00
parent 824cb201b1
commit 147d39cb7c
4 changed files with 104 additions and 77 deletions

View File

@ -1,25 +1,32 @@
"""SAMPLING ONLY.""" """SAMPLING ONLY."""
from typing import Union
import torch import torch
import numpy as np import numpy as np
from tqdm import tqdm from tqdm import tqdm
from functools import partial from functools import partial
from ldm.invoke.devices import choose_torch_device from ldm.invoke.devices import choose_torch_device
from ldm.models.diffusion.cross_attention import CrossAttentionControllableDiffusionMixin from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from ldm.models.diffusion.sampler import Sampler from ldm.models.diffusion.sampler import Sampler
from ldm.modules.diffusionmodules.util import noise_like from ldm.modules.diffusionmodules.util import noise_like
class DDIMSampler(Sampler, CrossAttentionControllableDiffusionMixin): class DDIMSampler(Sampler):
def __init__(self, model, schedule='linear', device=None, **kwargs): def __init__(self, model, schedule='linear', device=None, **kwargs):
super().__init__(model,schedule,model.num_timesteps,device) super().__init__(model,schedule,model.num_timesteps,device)
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.model,
model_forward_callback = lambda x, sigma, cond: self.model.apply_model(x, sigma, cond))
def prepare_to_sample(self, t_enc, **kwargs): def prepare_to_sample(self, t_enc, **kwargs):
super().prepare_to_sample(t_enc, **kwargs) super().prepare_to_sample(t_enc, **kwargs)
edited_conditioning = kwargs.get('edited_conditioning', None) edited_conditioning = kwargs.get('edited_conditioning', None)
edit_opcodes = kwargs.get('conditioning_edit_opcodes', None)
self.setup_cross_attention_control_if_appropriate(self.model, edited_conditioning, edit_opcodes) if edited_conditioning is not None:
edit_opcodes = kwargs.get('conditioning_edit_opcodes', None)
self.invokeai_diffuser.setup_cross_attention_control(edited_conditioning, edit_opcodes)
else:
self.invokeai_diffuser.cleanup_cross_attention_control()
# This is the central routine # This is the central routine
@ -27,7 +34,7 @@ class DDIMSampler(Sampler, CrossAttentionControllableDiffusionMixin):
def p_sample( def p_sample(
self, self,
x, x,
c, c: Union[torch.Tensor, list],
t, t,
index, index,
repeat_noise=False, repeat_noise=False,
@ -51,12 +58,7 @@ class DDIMSampler(Sampler, CrossAttentionControllableDiffusionMixin):
e_t = self.model.apply_model(x, t, c) e_t = self.model.apply_model(x, t, c)
else: else:
e_t_uncond, e_t = self.do_cross_attention_controllable_diffusion_step(x, t, unconditional_conditioning, c, self.model, e_t = self.invokeai_diffuser.do_diffusion_step(x, t, unconditional_conditioning, c, unconditional_guidance_scale)
model_forward_callback=lambda x, sigma, cond: self.model.apply_model(x, sigma, cond))
e_t = e_t_uncond + unconditional_guidance_scale * (
e_t - e_t_uncond
)
if score_corrector is not None: if score_corrector is not None:
assert self.model.parameterization == 'eps' assert self.model.parameterization == 'eps'

View File

@ -1,19 +1,11 @@
"""wrapper around part of Katherine Crowson's k-diffusion library, making it call compatible with other Samplers""" """wrapper around part of Katherine Crowson's k-diffusion library, making it call compatible with other Samplers"""
from enum import Enum
import k_diffusion as K import k_diffusion as K
import torch import torch
import torch.nn as nn from torch import nn
from ldm.invoke.devices import choose_torch_device
from ldm.models.diffusion.sampler import Sampler from .sampler import Sampler
from ldm.util import rand_perlin_2d from .shared_invokeai_diffusion import InvokeAIDiffuserComponent
from ldm.modules.diffusionmodules.util import (
make_ddim_sampling_parameters,
make_ddim_timesteps,
noise_like,
extract_into_tensor,
)
from ldm.models.diffusion.cross_attention import CrossAttentionControl, CrossAttentionControllableDiffusionMixin
def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7): def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
@ -30,27 +22,32 @@ def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
return torch.clamp(result, min=minval, max=maxval) return torch.clamp(result, min=minval, max=maxval)
class CFGDenoiser(nn.Module, CrossAttentionControllableDiffusionMixin): class CFGDenoiser(nn.Module):
def __init__(self, model, threshold = 0, warmup = 0): def __init__(self, model, threshold = 0, warmup = 0):
super().__init__() super().__init__()
self.inner_model = model self.inner_model = model
self.threshold = threshold self.threshold = threshold
self.warmup_max = warmup self.warmup_max = warmup
self.warmup = max(warmup / 10, 1) self.warmup = max(warmup / 10, 1)
self.invokeai_diffuser = InvokeAIDiffuserComponent(model,
model_forward_callback=lambda x, sigma, cond: self.inner_model(x, sigma, cond=cond))
def prepare_to_sample(self, t_enc, **kwargs): def prepare_to_sample(self, t_enc, **kwargs):
edited_conditioning = kwargs.get('edited_conditioning', None) edited_conditioning = kwargs.get('edited_conditioning', None)
conditioning_edit_opcodes = kwargs.get('conditioning_edit_opcodes', None)
self.setup_cross_attention_control_if_appropriate(self.inner_model, edited_conditioning, conditioning_edit_opcodes) if edited_conditioning is not None:
conditioning_edit_opcodes = kwargs.get('conditioning_edit_opcodes', None)
self.invokeai_diffuser.setup_cross_attention_control(edited_conditioning, conditioning_edit_opcodes)
else:
self.invokeai_diffuser.cleanup_cross_attention_control()
def forward(self, x, sigma, uncond, cond, cond_scale): def forward(self, x, sigma, uncond, cond, cond_scale):
unconditioned_next_x, conditioned_next_x = self.do_cross_attention_controllable_diffusion_step(x, sigma, uncond, cond, self.inner_model, final_next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale)
model_forward_callback=lambda x, sigma, cond: self.inner_model(x, sigma, cond=cond))
# apply threshold
if self.warmup < self.warmup_max: if self.warmup < self.warmup_max:
thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max)) thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
self.warmup += 1 self.warmup += 1
@ -58,9 +55,8 @@ class CFGDenoiser(nn.Module, CrossAttentionControllableDiffusionMixin):
thresh = self.threshold thresh = self.threshold
if thresh > self.threshold: if thresh > self.threshold:
thresh = self.threshold thresh = self.threshold
# to scale how much effect conditioning has, calculate the changes it does and then scale that return cfg_apply_threshold(final_next_x, thresh)
scaled_delta = (conditioned_next_x - unconditioned_next_x) * cond_scale
return cfg_apply_threshold(unconditioned_next_x + scaled_delta, thresh)
class KSampler(Sampler): class KSampler(Sampler):
@ -75,16 +71,6 @@ class KSampler(Sampler):
self.ds = None self.ds = None
self.s_in = None self.s_in = None
def forward(self, x, sigma, uncond, cond, cond_scale):
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(
x_in, sigma_in, cond=cond_in
).chunk(2)
return uncond + (cond - uncond) * cond_scale
def make_schedule( def make_schedule(
self, self,
ddim_num_steps, ddim_num_steps,
@ -303,3 +289,4 @@ class KSampler(Sampler):
Overrides parent method to return the q_sample of the inner model. Overrides parent method to return the q_sample of the inner model.
''' '''
return self.model.inner_model.q_sample(x0,ts) return self.model.inner_model.q_sample(x0,ts)

View File

@ -5,22 +5,28 @@ import numpy as np
from tqdm import tqdm from tqdm import tqdm
from functools import partial from functools import partial
from ldm.invoke.devices import choose_torch_device from ldm.invoke.devices import choose_torch_device
from ldm.models.diffusion.cross_attention import CrossAttentionControllableDiffusionMixin from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from ldm.models.diffusion.sampler import Sampler from ldm.models.diffusion.sampler import Sampler
from ldm.modules.diffusionmodules.util import noise_like from ldm.modules.diffusionmodules.util import noise_like
class PLMSSampler(Sampler, CrossAttentionControllableDiffusionMixin): class PLMSSampler(Sampler):
def __init__(self, model, schedule='linear', device=None, **kwargs): def __init__(self, model, schedule='linear', device=None, **kwargs):
super().__init__(model,schedule,model.num_timesteps, device) super().__init__(model,schedule,model.num_timesteps, device)
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.model,
model_forward_callback = lambda x, sigma, cond: self.model.apply_model(x, sigma, cond))
def prepare_to_sample(self, t_enc, **kwargs): def prepare_to_sample(self, t_enc, **kwargs):
super().prepare_to_sample(t_enc, **kwargs) super().prepare_to_sample(t_enc, **kwargs)
edited_conditioning = kwargs.get('edited_conditioning', None) edited_conditioning = kwargs.get('edited_conditioning', None)
edit_opcodes = kwargs.get('conditioning_edit_opcodes', None)
self.setup_cross_attention_control_if_appropriate(self.model, edited_conditioning, edit_opcodes) if edited_conditioning is not None:
edit_opcodes = kwargs.get('conditioning_edit_opcodes', None)
self.invokeai_diffuser.setup_cross_attention_control(edited_conditioning, edit_opcodes)
else:
self.invokeai_diffuser.cleanup_cross_attention_control()
# this is the essential routine # this is the essential routine
@ -51,21 +57,11 @@ class PLMSSampler(Sampler, CrossAttentionControllableDiffusionMixin):
unconditional_conditioning is None unconditional_conditioning is None
or unconditional_guidance_scale == 1.0 or unconditional_guidance_scale == 1.0
): ):
# damian0815 does not think this code path is ever used # damian0815 does not know if this code path is ever used
e_t = self.model.apply_model(x, t, c) e_t = self.model.apply_model(x, t, c)
else: else:
#x_in = torch.cat([x] * 2)
#t_in = torch.cat([t] * 2)
#c_in = torch.cat([unconditional_conditioning, c])
#e_t_uncond, e_t = self.model.apply_model(
# x_in, t_in, c_in
#).chunk(2)
e_t_uncond, e_t = self.do_cross_attention_controllable_diffusion_step(x, t, unconditional_conditioning, c, self.model,
model_forward_callback=lambda x, sigma, cond: self.model.apply_model(x, sigma, cond))
e_t = e_t_uncond + unconditional_guidance_scale * ( e_t = self.invokeai_diffuser.do_diffusion_step(x, t, unconditional_conditioning, c, unconditional_guidance_scale)
e_t - e_t_uncond
)
if score_corrector is not None: if score_corrector is not None:
assert self.model.parameterization == 'eps' assert self.model.parameterization == 'eps'

View File

@ -1,33 +1,70 @@
from enum import Enum from enum import Enum
from typing import Callable
import torch import torch
class InvokeAIDiffuserComponent:
class CrossAttentionControllableDiffusionMixin: class Conditioning:
def __init__(self, edited_conditioning: torch.Tensor = None, edit_opcodes: list[tuple] = None):
"""
:param edited_conditioning: if doing cross-attention control, the edited conditioning (1 x 77 x 768)
:param edit_opcodes: if doing cross-attention control, opcodes from a SequenceMatcher describing how to map original conditioning tokens to edited conditioning tokens
"""
#self.conditioning = conditioning
#self.unconditioning = unconditioning
self.edited_conditioning = edited_conditioning
self.edit_opcodes = edit_opcodes
def setup_cross_attention_control_if_appropriate(self, model, edited_conditioning, edit_opcodes): '''
The aim of this component is to provide a single place for code that can be applied identically to
all InvokeAI diffusion procedures.
At the moment it includes the following features:
* Cross Attention Control ("prompt2prompt")
'''
def __init__(self, model, model_forward_callback: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]):
"""
:param model: the unet model to pass through to cross attention control
:param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning)
"""
self.model = model
self.model_forward_callback = model_forward_callback
def setup_cross_attention_control(self, edited_conditioning, edit_opcodes):
self.edited_conditioning = edited_conditioning self.edited_conditioning = edited_conditioning
CrossAttentionControl.setup_attention_editing(self.model, edited_conditioning, edit_opcodes)
if edited_conditioning is not None: def cleanup_cross_attention_control(self):
# <start> a cat sitting on a car <end> self.edited_conditioning = None
CrossAttentionControl.setup_attention_editing(model, edited_conditioning, edit_opcodes) CrossAttentionControl.clear_attention_editing(self.model)
else:
# pass through the attention func but don't act on it
CrossAttentionControl.clear_attention_editing(model)
def cleanup_cross_attention_control(self, model):
CrossAttentionControl.clear_attention_editing(model)
def do_cross_attention_controllable_diffusion_step(self, x, sigma, unconditioning, conditioning, model, model_forward_callback): def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor,
unconditioning: torch.Tensor, conditioning: torch.Tensor,
unconditional_guidance_scale: float):
"""
:param x: Current latents
:param sigma: aka t, passed to the internal model to control how much denoising will occur
:param unconditioning: [B x 77 x 768] embeddings for unconditioned output
:param conditioning: [B x 77 x 768] embeddings for conditioned output
:param unconditional_guidance_scale: aka CFG scale, controls how much effect the conditioning tensor has
:param model: the unet model to pass through to cross attention control
:param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning)
:return: the new latents after applying the model to x using unconditioning and CFG-scaled conditioning.
"""
CrossAttentionControl.clear_requests(model) CrossAttentionControl.clear_requests(self.model)
if self.edited_conditioning is None: if self.edited_conditioning is None:
# faster batched path # faster batched path
x_twice = torch.cat([x]*2) x_twice = torch.cat([x]*2)
sigma_twice = torch.cat([sigma]*2) sigma_twice = torch.cat([sigma]*2)
both_conditionings = torch.cat([unconditioning, conditioning]) both_conditionings = torch.cat([unconditioning, conditioning])
unconditioned_next_x, conditioned_next_x = model_forward_callback(x_twice, sigma_twice, both_conditionings).chunk(2) unconditioned_next_x, conditioned_next_x = self.model_forward_callback(x_twice, sigma_twice, both_conditionings).chunk(2)
else: else:
# slower non-batched path (20% slower on mac MPS) # slower non-batched path (20% slower on mac MPS)
# We are only interested in using attention maps for conditioned_next_x, but batching them with generation of # We are only interested in using attention maps for conditioned_next_x, but batching them with generation of
@ -37,19 +74,24 @@ class CrossAttentionControllableDiffusionMixin:
# representing batched uncond + cond, but then when it comes to applying the saved attention, the # representing batched uncond + cond, but then when it comes to applying the saved attention, the
# wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.) # wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.)
# todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well. # todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well.
unconditioned_next_x = model_forward_callback(x, sigma, unconditioning) unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
# process x using the original prompt, saving the attention maps # process x using the original prompt, saving the attention maps
CrossAttentionControl.request_save_attention_maps(model) CrossAttentionControl.request_save_attention_maps(self.model)
_ = model_forward_callback(x, sigma, cond=conditioning) _ = self.model_forward_callback(x, sigma, cond=conditioning)
CrossAttentionControl.clear_requests(model) CrossAttentionControl.clear_requests(self.model)
# process x again, using the saved attention maps to control where self.edited_conditioning will be applied # process x again, using the saved attention maps to control where self.edited_conditioning will be applied
CrossAttentionControl.request_apply_saved_attention_maps(model) CrossAttentionControl.request_apply_saved_attention_maps(self.model)
conditioned_next_x = model_forward_callback(x, sigma, self.edited_conditioning) conditioned_next_x = self.model_forward_callback(x, sigma, self.edited_conditioning)
CrossAttentionControl.clear_requests(model) CrossAttentionControl.clear_requests(model)
return unconditioned_next_x, conditioned_next_x
# to scale how much effect conditioning has, calculate the changes it does and then scale that
scaled_delta = (conditioned_next_x - unconditioned_next_x) * unconditional_guidance_scale
combined_next_x = unconditioned_next_x + scaled_delta
return combined_next_x
# adapted from bloc97's CrossAttentionControl colab # adapted from bloc97's CrossAttentionControl colab
# https://github.com/bloc97/CrossAttentionControl # https://github.com/bloc97/CrossAttentionControl