mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
wip refactoring shared InvokeAI diffuser mixin to component
This commit is contained in:
@ -1,19 +1,11 @@
|
||||
"""wrapper around part of Katherine Crowson's k-diffusion library, making it call compatible with other Samplers"""
|
||||
from enum import Enum
|
||||
|
||||
import k_diffusion as K
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from ldm.invoke.devices import choose_torch_device
|
||||
from ldm.models.diffusion.sampler import Sampler
|
||||
from ldm.util import rand_perlin_2d
|
||||
from ldm.modules.diffusionmodules.util import (
|
||||
make_ddim_sampling_parameters,
|
||||
make_ddim_timesteps,
|
||||
noise_like,
|
||||
extract_into_tensor,
|
||||
)
|
||||
from ldm.models.diffusion.cross_attention import CrossAttentionControl, CrossAttentionControllableDiffusionMixin
|
||||
from torch import nn
|
||||
|
||||
from .sampler import Sampler
|
||||
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
|
||||
|
||||
def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
|
||||
@ -30,27 +22,32 @@ def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
|
||||
return torch.clamp(result, min=minval, max=maxval)
|
||||
|
||||
|
||||
class CFGDenoiser(nn.Module, CrossAttentionControllableDiffusionMixin):
|
||||
class CFGDenoiser(nn.Module):
|
||||
def __init__(self, model, threshold = 0, warmup = 0):
|
||||
super().__init__()
|
||||
self.inner_model = model
|
||||
self.threshold = threshold
|
||||
self.warmup_max = warmup
|
||||
self.warmup = max(warmup / 10, 1)
|
||||
self.invokeai_diffuser = InvokeAIDiffuserComponent(model,
|
||||
model_forward_callback=lambda x, sigma, cond: self.inner_model(x, sigma, cond=cond))
|
||||
|
||||
def prepare_to_sample(self, t_enc, **kwargs):
|
||||
|
||||
edited_conditioning = kwargs.get('edited_conditioning', None)
|
||||
conditioning_edit_opcodes = kwargs.get('conditioning_edit_opcodes', None)
|
||||
|
||||
self.setup_cross_attention_control_if_appropriate(self.inner_model, edited_conditioning, conditioning_edit_opcodes)
|
||||
if edited_conditioning is not None:
|
||||
conditioning_edit_opcodes = kwargs.get('conditioning_edit_opcodes', None)
|
||||
self.invokeai_diffuser.setup_cross_attention_control(edited_conditioning, conditioning_edit_opcodes)
|
||||
else:
|
||||
self.invokeai_diffuser.cleanup_cross_attention_control()
|
||||
|
||||
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale):
|
||||
|
||||
unconditioned_next_x, conditioned_next_x = self.do_cross_attention_controllable_diffusion_step(x, sigma, uncond, cond, self.inner_model,
|
||||
model_forward_callback=lambda x, sigma, cond: self.inner_model(x, sigma, cond=cond))
|
||||
final_next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale)
|
||||
|
||||
# apply threshold
|
||||
if self.warmup < self.warmup_max:
|
||||
thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
|
||||
self.warmup += 1
|
||||
@ -58,9 +55,8 @@ class CFGDenoiser(nn.Module, CrossAttentionControllableDiffusionMixin):
|
||||
thresh = self.threshold
|
||||
if thresh > self.threshold:
|
||||
thresh = self.threshold
|
||||
# to scale how much effect conditioning has, calculate the changes it does and then scale that
|
||||
scaled_delta = (conditioned_next_x - unconditioned_next_x) * cond_scale
|
||||
return cfg_apply_threshold(unconditioned_next_x + scaled_delta, thresh)
|
||||
return cfg_apply_threshold(final_next_x, thresh)
|
||||
|
||||
|
||||
|
||||
class KSampler(Sampler):
|
||||
@ -75,16 +71,6 @@ class KSampler(Sampler):
|
||||
self.ds = None
|
||||
self.s_in = None
|
||||
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale):
|
||||
x_in = torch.cat([x] * 2)
|
||||
sigma_in = torch.cat([sigma] * 2)
|
||||
cond_in = torch.cat([uncond, cond])
|
||||
uncond, cond = self.inner_model(
|
||||
x_in, sigma_in, cond=cond_in
|
||||
).chunk(2)
|
||||
return uncond + (cond - uncond) * cond_scale
|
||||
|
||||
|
||||
def make_schedule(
|
||||
self,
|
||||
ddim_num_steps,
|
||||
@ -303,3 +289,4 @@ class KSampler(Sampler):
|
||||
Overrides parent method to return the q_sample of the inner model.
|
||||
'''
|
||||
return self.model.inner_model.q_sample(x0,ts)
|
||||
|
||||
|
Reference in New Issue
Block a user