mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
wip refactoring shared InvokeAI diffuser mixin to component
This commit is contained in:
parent
824cb201b1
commit
147d39cb7c
@ -1,25 +1,32 @@
|
||||
"""SAMPLING ONLY."""
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from functools import partial
|
||||
from ldm.invoke.devices import choose_torch_device
|
||||
from ldm.models.diffusion.cross_attention import CrossAttentionControllableDiffusionMixin
|
||||
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
from ldm.models.diffusion.sampler import Sampler
|
||||
from ldm.modules.diffusionmodules.util import noise_like
|
||||
|
||||
class DDIMSampler(Sampler, CrossAttentionControllableDiffusionMixin):
|
||||
class DDIMSampler(Sampler):
|
||||
def __init__(self, model, schedule='linear', device=None, **kwargs):
    """
    Build a DDIM sampler and attach the shared InvokeAI diffusion component.

    :param model: diffusion model; its `num_timesteps` is read here and its
        `apply_model(x, t, cond)` is what the diffusion component calls
    :param schedule: forwarded unchanged to the `Sampler` initialiser
    :param device: forwarded unchanged to the `Sampler` initialiser
    :param kwargs: accepted for call compatibility; not used in this method
    """
    super().__init__(model, schedule, model.num_timesteps, device)

    def apply_model_callback(x, sigma, cond):
        # Resolve self.model on every call instead of capturing `model`, so the
        # component always forwards through whatever the sampler currently holds.
        # (self.model is presumably assigned by Sampler.__init__ - confirm there.)
        return self.model.apply_model(x, sigma, cond)

    self.invokeai_diffuser = InvokeAIDiffuserComponent(
        self.model,
        model_forward_callback=apply_model_callback,
    )
|
||||
|
||||
def prepare_to_sample(self, t_enc, **kwargs):
    """
    Prepare the sampler for a run and configure cross-attention control.

    Cross-attention control is switched on when an edited conditioning is
    supplied and explicitly cleared otherwise, so state left over from a
    previous run is never reused.

    :param t_enc: forwarded unchanged to `Sampler.prepare_to_sample`
    :param kwargs: forwarded to `Sampler.prepare_to_sample`; two keys are read here:
        `edited_conditioning` (optional) and `conditioning_edit_opcodes`
        (only consulted when `edited_conditioning` is given)
    """
    super().prepare_to_sample(t_enc, **kwargs)

    edited_conditioning = kwargs.get('edited_conditioning', None)

    # All cross-attention handling goes through the diffuser component; this
    # class no longer inherits the mixin that used to provide
    # setup_cross_attention_control_if_appropriate().
    if edited_conditioning is not None:
        edit_opcodes = kwargs.get('conditioning_edit_opcodes', None)
        self.invokeai_diffuser.setup_cross_attention_control(edited_conditioning, edit_opcodes)
    else:
        self.invokeai_diffuser.cleanup_cross_attention_control()
|
||||
|
||||
|
||||
# This is the central routine
|
||||
@ -27,7 +34,7 @@ class DDIMSampler(Sampler, CrossAttentionControllableDiffusionMixin):
|
||||
def p_sample(
|
||||
self,
|
||||
x,
|
||||
c,
|
||||
c: Union[torch.Tensor, list],
|
||||
t,
|
||||
index,
|
||||
repeat_noise=False,
|
||||
@ -51,12 +58,7 @@ class DDIMSampler(Sampler, CrossAttentionControllableDiffusionMixin):
|
||||
e_t = self.model.apply_model(x, t, c)
|
||||
else:
|
||||
|
||||
e_t_uncond, e_t = self.do_cross_attention_controllable_diffusion_step(x, t, unconditional_conditioning, c, self.model,
|
||||
model_forward_callback=lambda x, sigma, cond: self.model.apply_model(x, sigma, cond))
|
||||
|
||||
e_t = e_t_uncond + unconditional_guidance_scale * (
|
||||
e_t - e_t_uncond
|
||||
)
|
||||
e_t = self.invokeai_diffuser.do_diffusion_step(x, t, unconditional_conditioning, c, unconditional_guidance_scale)
|
||||
|
||||
if score_corrector is not None:
|
||||
assert self.model.parameterization == 'eps'
|
||||
|
@ -1,19 +1,11 @@
|
||||
"""wrapper around part of Katherine Crowson's k-diffusion library, making it call compatible with other Samplers"""
|
||||
from enum import Enum
|
||||
|
||||
import k_diffusion as K
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from ldm.invoke.devices import choose_torch_device
|
||||
from ldm.models.diffusion.sampler import Sampler
|
||||
from ldm.util import rand_perlin_2d
|
||||
from ldm.modules.diffusionmodules.util import (
|
||||
make_ddim_sampling_parameters,
|
||||
make_ddim_timesteps,
|
||||
noise_like,
|
||||
extract_into_tensor,
|
||||
)
|
||||
from ldm.models.diffusion.cross_attention import CrossAttentionControl, CrossAttentionControllableDiffusionMixin
|
||||
from torch import nn
|
||||
|
||||
from .sampler import Sampler
|
||||
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
|
||||
|
||||
def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
|
||||
@ -30,27 +22,32 @@ def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
|
||||
return torch.clamp(result, min=minval, max=maxval)
|
||||
|
||||
|
||||
class CFGDenoiser(nn.Module, CrossAttentionControllableDiffusionMixin):
|
||||
class CFGDenoiser(nn.Module):
|
||||
def __init__(self, model, threshold = 0, warmup = 0):
    """
    Wrap `model` for classifier-free-guidance denoising.

    :param model: inner denoiser; the diffusion component calls it as
        `model(x, sigma, cond=...)`
    :param threshold: stored as `self.threshold`
    :param warmup: stored as `self.warmup_max`; the running `self.warmup`
        counter starts at `max(warmup / 10, 1)`
    """
    super().__init__()
    self.inner_model = model
    self.threshold = threshold
    self.warmup_max = warmup
    self.warmup = max(warmup / 10, 1)

    def inner_model_callback(x, sigma, cond):
        # Go through self.inner_model at call time, as the original lambda did.
        return self.inner_model(x, sigma, cond=cond)

    self.invokeai_diffuser = InvokeAIDiffuserComponent(
        model,
        model_forward_callback=inner_model_callback,
    )
|
||||
|
||||
def prepare_to_sample(self, t_enc, **kwargs):
    """
    Configure cross-attention control for the upcoming sampling run.

    Cross-attention control is switched on when an edited conditioning is
    supplied and explicitly cleared otherwise, so state left over from a
    previous run is never reused.

    :param t_enc: not used by this method; kept so the signature matches the
        samplers' `prepare_to_sample`
    :param kwargs: two keys are read: `edited_conditioning` (optional) and
        `conditioning_edit_opcodes` (only consulted when `edited_conditioning`
        is given)
    """
    edited_conditioning = kwargs.get('edited_conditioning', None)

    # All cross-attention handling goes through the diffuser component; this
    # class no longer inherits the mixin that used to provide
    # setup_cross_attention_control_if_appropriate().
    if edited_conditioning is not None:
        conditioning_edit_opcodes = kwargs.get('conditioning_edit_opcodes', None)
        self.invokeai_diffuser.setup_cross_attention_control(edited_conditioning, conditioning_edit_opcodes)
    else:
        self.invokeai_diffuser.cleanup_cross_attention_control()
|
||||
|
||||
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale):
|
||||
|
||||
unconditioned_next_x, conditioned_next_x = self.do_cross_attention_controllable_diffusion_step(x, sigma, uncond, cond, self.inner_model,
|
||||
model_forward_callback=lambda x, sigma, cond: self.inner_model(x, sigma, cond=cond))
|
||||
final_next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale)
|
||||
|
||||
# apply threshold
|
||||
if self.warmup < self.warmup_max:
|
||||
thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
|
||||
self.warmup += 1
|
||||
@ -58,9 +55,8 @@ class CFGDenoiser(nn.Module, CrossAttentionControllableDiffusionMixin):
|
||||
thresh = self.threshold
|
||||
if thresh > self.threshold:
|
||||
thresh = self.threshold
|
||||
# to scale how much effect conditioning has, calculate the changes it does and then scale that
|
||||
scaled_delta = (conditioned_next_x - unconditioned_next_x) * cond_scale
|
||||
return cfg_apply_threshold(unconditioned_next_x + scaled_delta, thresh)
|
||||
return cfg_apply_threshold(final_next_x, thresh)
|
||||
|
||||
|
||||
|
||||
class KSampler(Sampler):
|
||||
@ -75,16 +71,6 @@ class KSampler(Sampler):
|
||||
self.ds = None
|
||||
self.s_in = None
|
||||
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale):
|
||||
x_in = torch.cat([x] * 2)
|
||||
sigma_in = torch.cat([sigma] * 2)
|
||||
cond_in = torch.cat([uncond, cond])
|
||||
uncond, cond = self.inner_model(
|
||||
x_in, sigma_in, cond=cond_in
|
||||
).chunk(2)
|
||||
return uncond + (cond - uncond) * cond_scale
|
||||
|
||||
|
||||
def make_schedule(
|
||||
self,
|
||||
ddim_num_steps,
|
||||
@ -303,3 +289,4 @@ class KSampler(Sampler):
|
||||
Overrides parent method to return the q_sample of the inner model.
|
||||
'''
|
||||
return self.model.inner_model.q_sample(x0,ts)
|
||||
|
||||
|
@ -5,22 +5,28 @@ import numpy as np
|
||||
from tqdm import tqdm
|
||||
from functools import partial
|
||||
from ldm.invoke.devices import choose_torch_device
|
||||
from ldm.models.diffusion.cross_attention import CrossAttentionControllableDiffusionMixin
|
||||
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
from ldm.models.diffusion.sampler import Sampler
|
||||
from ldm.modules.diffusionmodules.util import noise_like
|
||||
|
||||
|
||||
class PLMSSampler(Sampler, CrossAttentionControllableDiffusionMixin):
|
||||
class PLMSSampler(Sampler):
|
||||
def __init__(self, model, schedule='linear', device=None, **kwargs):
    """
    Build a PLMS sampler and attach the shared InvokeAI diffusion component.

    :param model: diffusion model; its `num_timesteps` is read here and its
        `apply_model(x, t, cond)` is what the diffusion component calls
    :param schedule: forwarded unchanged to the `Sampler` initialiser
    :param device: forwarded unchanged to the `Sampler` initialiser
    :param kwargs: accepted for call compatibility; not used in this method
    """
    super().__init__(model, schedule, model.num_timesteps, device)

    def apply_model_callback(x, sigma, cond):
        # Resolve self.model on every call instead of capturing `model`, so the
        # component always forwards through whatever the sampler currently holds.
        # (self.model is presumably assigned by Sampler.__init__ - confirm there.)
        return self.model.apply_model(x, sigma, cond)

    self.invokeai_diffuser = InvokeAIDiffuserComponent(
        self.model,
        model_forward_callback=apply_model_callback,
    )
|
||||
|
||||
def prepare_to_sample(self, t_enc, **kwargs):
    """
    Prepare the sampler for a run and configure cross-attention control.

    Cross-attention control is switched on when an edited conditioning is
    supplied and explicitly cleared otherwise, so state left over from a
    previous run is never reused.

    :param t_enc: forwarded unchanged to `Sampler.prepare_to_sample`
    :param kwargs: forwarded to `Sampler.prepare_to_sample`; two keys are read here:
        `edited_conditioning` (optional) and `conditioning_edit_opcodes`
        (only consulted when `edited_conditioning` is given)
    """
    super().prepare_to_sample(t_enc, **kwargs)

    edited_conditioning = kwargs.get('edited_conditioning', None)

    # All cross-attention handling goes through the diffuser component; this
    # class no longer inherits the mixin that used to provide
    # setup_cross_attention_control_if_appropriate().
    if edited_conditioning is not None:
        edit_opcodes = kwargs.get('conditioning_edit_opcodes', None)
        self.invokeai_diffuser.setup_cross_attention_control(edited_conditioning, edit_opcodes)
    else:
        self.invokeai_diffuser.cleanup_cross_attention_control()
|
||||
|
||||
|
||||
# this is the essential routine
|
||||
@ -51,21 +57,11 @@ class PLMSSampler(Sampler, CrossAttentionControllableDiffusionMixin):
|
||||
unconditional_conditioning is None
|
||||
or unconditional_guidance_scale == 1.0
|
||||
):
|
||||
# damian0815 does not think this code path is ever used
|
||||
# damian0815 does not know if this code path is ever used
|
||||
e_t = self.model.apply_model(x, t, c)
|
||||
else:
|
||||
#x_in = torch.cat([x] * 2)
|
||||
#t_in = torch.cat([t] * 2)
|
||||
#c_in = torch.cat([unconditional_conditioning, c])
|
||||
#e_t_uncond, e_t = self.model.apply_model(
|
||||
# x_in, t_in, c_in
|
||||
#).chunk(2)
|
||||
e_t_uncond, e_t = self.do_cross_attention_controllable_diffusion_step(x, t, unconditional_conditioning, c, self.model,
|
||||
model_forward_callback=lambda x, sigma, cond: self.model.apply_model(x, sigma, cond))
|
||||
|
||||
e_t = e_t_uncond + unconditional_guidance_scale * (
|
||||
e_t - e_t_uncond
|
||||
)
|
||||
e_t = self.invokeai_diffuser.do_diffusion_step(x, t, unconditional_conditioning, c, unconditional_guidance_scale)
|
||||
|
||||
if score_corrector is not None:
|
||||
assert self.model.parameterization == 'eps'
|
||||
|
@ -1,33 +1,70 @@
|
||||
from enum import Enum
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class InvokeAIDiffuserComponent:
|
||||
|
||||
class CrossAttentionControllableDiffusionMixin:
|
||||
class Conditioning:
    # Plain holder for the optional cross-attention-control inputs.

    def __init__(self, edited_conditioning: torch.Tensor = None, edit_opcodes: list[tuple] = None):
        """
        Store the cross-attention-control inputs; both stay None when the feature is not in use.

        :param edited_conditioning: the edited conditioning (1 x 77 x 768) when doing cross-attention control
        :param edit_opcodes: SequenceMatcher opcodes describing how original conditioning tokens map
            onto edited conditioning tokens, when doing cross-attention control
        """
        self.edit_opcodes = edit_opcodes
        self.edited_conditioning = edited_conditioning
|
||||
|
||||
def setup_cross_attention_control_if_appropriate(self, model, edited_conditioning, edit_opcodes):
|
||||
'''
|
||||
The aim of this component is to provide a single place for code that can be applied identically to
|
||||
all InvokeAI diffusion procedures.
|
||||
|
||||
At the moment it includes the following features:
|
||||
* Cross Attention Control ("prompt2prompt")
|
||||
'''
|
||||
|
||||
def __init__(self, model, model_forward_callback: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]):
|
||||
"""
|
||||
:param model: the unet model to pass through to cross attention control
|
||||
:param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning)
|
||||
"""
|
||||
self.model = model
|
||||
self.model_forward_callback = model_forward_callback
|
||||
|
||||
|
||||
def setup_cross_attention_control(self, edited_conditioning, edit_opcodes):
    """
    Enable cross-attention control for subsequent diffusion steps.

    :param edited_conditioning: remembered on the component (do_diffusion_step reads it) and
        handed to CrossAttentionControl.setup_attention_editing
    :param edit_opcodes: handed to CrossAttentionControl.setup_attention_editing unchanged
    """
    # Record the edited conditioning before touching the model, as the original did.
    self.edited_conditioning = edited_conditioning
    unet = self.model
    CrossAttentionControl.setup_attention_editing(unet, edited_conditioning, edit_opcodes)
|
||||
|
||||
if edited_conditioning is not None:
|
||||
# <start> a cat sitting on a car <end>
|
||||
CrossAttentionControl.setup_attention_editing(model, edited_conditioning, edit_opcodes)
|
||||
else:
|
||||
# pass through the attention func but don't act on it
|
||||
CrossAttentionControl.clear_attention_editing(model)
|
||||
def cleanup_cross_attention_control(self):
    """
    Disable cross-attention control: forget the edited conditioning and clear
    attention editing on the wrapped model.
    """
    # Reset the local flag before touching the model, as the original did.
    self.edited_conditioning = None
    unet = self.model
    CrossAttentionControl.clear_attention_editing(unet)
|
||||
|
||||
def cleanup_cross_attention_control(self, model):
|
||||
CrossAttentionControl.clear_attention_editing(model)
|
||||
|
||||
def do_cross_attention_controllable_diffusion_step(self, x, sigma, unconditioning, conditioning, model, model_forward_callback):
|
||||
def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor,
                      unconditioning: torch.Tensor, conditioning: torch.Tensor,
                      unconditional_guidance_scale: float):
    """
    Run the model for one step and combine the unconditioned and conditioned
    outputs with classifier-free guidance.

    The model and its forward callback are the ones given to the constructor.

    :param x: Current latents
    :param sigma: aka t, passed to the internal model to control how much denoising will occur
    :param unconditioning: [B x 77 x 768] embeddings for unconditioned output
    :param conditioning: [B x 77 x 768] embeddings for conditioned output
    :param unconditional_guidance_scale: aka CFG scale, controls how much effect the conditioning tensor has
    :return: the new latents after applying the model to x using unconditioning and CFG-scaled conditioning.
    """

    CrossAttentionControl.clear_requests(self.model)

    if self.edited_conditioning is None:
        # faster batched path
        x_twice = torch.cat([x]*2)
        sigma_twice = torch.cat([sigma]*2)
        both_conditionings = torch.cat([unconditioning, conditioning])
        unconditioned_next_x, conditioned_next_x = self.model_forward_callback(x_twice, sigma_twice, both_conditionings).chunk(2)
    else:
        # slower non-batched path (20% slower on mac MPS)
        # We only want attention maps for the conditioned pass. If that pass were batched together
        # with the unconditioned one, the maps would be saved for the whole batch (uncond + cond),
        # while the later "apply" pass runs on self.edited_conditioning alone, so the saved and the
        # live attention tensors would disagree in shape[0].
        # todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well.
        unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)

        # process x using the original prompt, saving the attention maps
        CrossAttentionControl.request_save_attention_maps(self.model)
        # called positionally, like every other use of the callback
        _ = self.model_forward_callback(x, sigma, conditioning)
        CrossAttentionControl.clear_requests(self.model)

        # process x again, using the saved attention maps to control where self.edited_conditioning will be applied
        CrossAttentionControl.request_apply_saved_attention_maps(self.model)
        conditioned_next_x = self.model_forward_callback(x, sigma, self.edited_conditioning)
        # `model` is not a parameter of this method; use the model held by the component
        CrossAttentionControl.clear_requests(self.model)

    # to scale how much effect conditioning has, calculate the changes it does and then scale that
    scaled_delta = (conditioned_next_x - unconditioned_next_x) * unconditional_guidance_scale
    combined_next_x = unconditioned_next_x + scaled_delta

    return combined_next_x
|
||||
|
||||
# adapted from bloc97's CrossAttentionControl colab
|
||||
# https://github.com/bloc97/CrossAttentionControl
|
Loading…
Reference in New Issue
Block a user