runs but doesn't work properly - see below for test prompt

test prompt:
"a cat sitting on a car {a dog sitting on a car}" -W 384 -H 256 -s 10 -S 12346 -A k_euler
note that substition of dog for cat is currently hard-coded (ksampler.py
	line 43-44)
This commit is contained in:
Damian at mba
2022-10-16 20:39:47 +02:00
parent 33d6603fef
commit 8ff507b03b
8 changed files with 207 additions and 199 deletions

View File

@ -13,6 +13,7 @@ from ldm.modules.diffusionmodules.util import (
noise_like,
extract_into_tensor,
)
from ldm.models.diffusion.cross_attention import CrossAttentionControl
def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
if threshold <= 0.0:
@ -29,21 +30,41 @@ def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
class CFGDenoiser(nn.Module):
def __init__(self, model, threshold = 0, warmup = 0):
def __init__(self, model, threshold = 0, warmup = 0, edited_conditioning = None):
super().__init__()
self.inner_model = model
self.threshold = threshold
self.warmup_max = warmup
self.warmup = max(warmup / 10, 1)
self.edited_conditioning = edited_conditioning
if self.edited_conditioning is not None:
initial_tokens_count = 77 # '<start> a cat sitting on a car <end>'
token_indices_to_edit = [2] # 'cat'
CrossAttentionControl.setup_attention_editing(self.inner_model, initial_tokens_count, edited_conditioning, token_indices_to_edit)
def forward(self, x, sigma, uncond, cond, cond_scale):
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
module = self.get_attention_module(AttentionLayer.TOKENS)
print('generating new unconditioned latents')
unconditioned_latents = self.inner_model(x, sigma, cond=uncond)
# process x using the original prompt, saving the attention maps if required
if self.edited_conditioning is not None:
# this is automatically toggled off after the model forward()
CrossAttentionControl.request_save_attention_maps(self.inner_model)
print('generating new conditioned latents')
conditioned_latents = self.inner_model(x, sigma, cond=cond)
if self.edited_conditioning is not None:
# process x again, using the saved attention maps but the new conditioning
# this is automatically toggled off after the model forward()
CrossAttentionControl.request_apply_saved_attention_maps(self.inner_model)
print('generating edited conditioned latents')
conditioned_latents = self.inner_model(x, sigma, cond=self.edited_conditioning)
if self.warmup < self.warmup_max:
thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
@ -52,7 +73,8 @@ class CFGDenoiser(nn.Module):
thresh = self.threshold
if thresh > self.threshold:
thresh = self.threshold
return cfg_apply_threshold(uncond + (cond - uncond) * cond_scale, thresh)
delta = (conditioned_latents - unconditioned_latents)
return cfg_apply_threshold(unconditioned_latents + delta * cond_scale, thresh)
class KSampler(Sampler):
@ -169,6 +191,7 @@ class KSampler(Sampler):
log_every_t=100,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
edited_conditioning=None,
threshold = 0,
perlin = 0,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
@ -200,7 +223,7 @@ class KSampler(Sampler):
else:
x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0]
model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10))
model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10), edited_conditioning=edited_conditioning)
extra_args = {
'cond': conditioning,
'uncond': unconditional_conditioning,