mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
bring in attention etc.
This commit is contained in:
commit
4c1267338b
@ -527,7 +527,7 @@ def parameters_to_generated_image_metadata(parameters):
|
||||
rfc_dict["sampler"] = parameters["sampler_name"]
|
||||
|
||||
# display weighted subprompts (liable to change)
|
||||
subprompts = split_weighted_subprompts(parameters["prompt"])
|
||||
subprompts = split_weighted_subprompts(parameters["prompt"], skip_normalize=True)
|
||||
subprompts = [{"prompt": x[0], "weight": x[1]} for x in subprompts]
|
||||
rfc_dict["prompt"] = subprompts
|
||||
|
||||
|
@ -76,4 +76,4 @@ model:
|
||||
target: torch.nn.Identity
|
||||
|
||||
cond_stage_config:
|
||||
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
||||
target: ldm.modules.encoders.modules.WeightedFrozenCLIPEmbedder
|
||||
|
@ -97,7 +97,8 @@ def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: Fl
|
||||
if type(flattened_prompt) is not FlattenedPrompt:
|
||||
raise f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead"
|
||||
fragments = [x.text for x in flattened_prompt.children]
|
||||
embeddings, tokens = model.get_learned_conditioning([' '.join(fragments)], return_tokens=True)
|
||||
weights = [x.weight for x in flattened_prompt.children]
|
||||
embeddings, tokens = model.get_learned_conditioning([fragments], return_tokens=True, fragment_weights=[weights])
|
||||
return embeddings, tokens
|
||||
|
||||
|
||||
|
@ -329,7 +329,7 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float)
|
||||
|
||||
def parse_fragment_str(x, in_quotes: bool=False, in_parens: bool=False):
|
||||
fragment_string = x[0]
|
||||
print(f"parsing fragment string \"{fragment_string}\"")
|
||||
#print(f"parsing fragment string \"{fragment_string}\"")
|
||||
if len(fragment_string.strip()) == 0:
|
||||
return Fragment('')
|
||||
|
||||
|
@ -183,7 +183,6 @@ class KSampler(Sampler):
|
||||
)
|
||||
|
||||
# sigmas are set up in make_schedule - we take the last steps items
|
||||
total_steps = len(self.sigmas)
|
||||
sigmas = self.sigmas[-S-1:]
|
||||
|
||||
# x_T is variation noise. When an init image is provided (in x0) we need to add
|
||||
|
@ -4,7 +4,7 @@ ldm.models.diffusion.sampler
|
||||
Base class for ldm.models.diffusion.ddim, ldm.models.diffusion.ksampler, etc
|
||||
|
||||
'''
|
||||
from enum import Enum
|
||||
from math import ceil
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
@ -1,4 +1,5 @@
|
||||
from enum import Enum
|
||||
from math import ceil
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
@ -104,6 +105,58 @@ class InvokeAIDiffuserComponent:
|
||||
|
||||
return combined_next_x
|
||||
|
||||
|
||||
# todo: make this work
|
||||
@classmethod
|
||||
def apply_conjunction(cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale):
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t] * 2) # aka sigmas
|
||||
|
||||
deltas = None
|
||||
uncond_latents = None
|
||||
weighted_cond_list = c_or_weighted_c_list if type(c_or_weighted_c_list) is list else [(c_or_weighted_c_list, 1)]
|
||||
|
||||
# below is fugly omg
|
||||
num_actual_conditionings = len(c_or_weighted_c_list)
|
||||
conditionings = [uc] + [c for c,weight in weighted_cond_list]
|
||||
weights = [1] + [weight for c,weight in weighted_cond_list]
|
||||
chunk_count = ceil(len(conditionings)/2)
|
||||
deltas = None
|
||||
for chunk_index in range(chunk_count):
|
||||
offset = chunk_index*2
|
||||
chunk_size = min(2, len(conditionings)-offset)
|
||||
|
||||
if chunk_size == 1:
|
||||
c_in = conditionings[offset]
|
||||
latents_a = forward_func(x_in[:-1], t_in[:-1], c_in)
|
||||
latents_b = None
|
||||
else:
|
||||
c_in = torch.cat(conditionings[offset:offset+2])
|
||||
latents_a, latents_b = forward_func(x_in, t_in, c_in).chunk(2)
|
||||
|
||||
# first chunk is guaranteed to be 2 entries: uncond_latents + first conditioining
|
||||
if chunk_index == 0:
|
||||
uncond_latents = latents_a
|
||||
deltas = latents_b - uncond_latents
|
||||
else:
|
||||
deltas = torch.cat((deltas, latents_a - uncond_latents))
|
||||
if latents_b is not None:
|
||||
deltas = torch.cat((deltas, latents_b - uncond_latents))
|
||||
|
||||
# merge the weighted deltas together into a single merged delta
|
||||
per_delta_weights = torch.tensor(weights[1:], dtype=deltas.dtype, device=deltas.device)
|
||||
normalize = False
|
||||
if normalize:
|
||||
per_delta_weights /= torch.sum(per_delta_weights)
|
||||
reshaped_weights = per_delta_weights.reshape(per_delta_weights.shape + (1, 1, 1))
|
||||
deltas_merged = torch.sum(deltas * reshaped_weights, dim=0, keepdim=True)
|
||||
|
||||
# old_return_value = super().forward(x, sigma, uncond, cond, cond_scale)
|
||||
# assert(0 == len(torch.nonzero(old_return_value - (uncond_latents + deltas_merged * cond_scale))))
|
||||
|
||||
return uncond_latents + deltas_merged * global_guidance_scale
|
||||
|
||||
|
||||
# adapted from bloc97's CrossAttentionControl colab
|
||||
# https://github.com/bloc97/CrossAttentionControl
|
||||
|
||||
|
@ -440,12 +440,6 @@ class FrozenCLIPEmbedder(AbstractEncoder):
|
||||
|
||||
def forward(self, text, **kwargs):
|
||||
|
||||
should_return_tokens = False
|
||||
if 'return_tokens' in kwargs:
|
||||
should_return_tokens = kwargs.get('return_tokens', False)
|
||||
# self.transformer doesn't like having extra kwargs
|
||||
kwargs.pop('return_tokens')
|
||||
|
||||
batch_encoding = self.tokenizer(
|
||||
text,
|
||||
truncation=True,
|
||||
@ -458,22 +452,211 @@ class FrozenCLIPEmbedder(AbstractEncoder):
|
||||
tokens = batch_encoding['input_ids'].to(self.device)
|
||||
z = self.transformer(input_ids=tokens, **kwargs)
|
||||
|
||||
if should_return_tokens:
|
||||
return z, tokens
|
||||
else:
|
||||
return z
|
||||
return z
|
||||
|
||||
def encode(self, text, **kwargs):
|
||||
return self(text, **kwargs)
|
||||
|
||||
class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
|
||||
|
||||
fragment_weights_key = "fragment_weights"
|
||||
return_tokens_key = "return_tokens"
|
||||
|
||||
def forward(self, text: list, **kwargs):
|
||||
'''
|
||||
|
||||
:param text: A batch of prompt strings, or, a batch of lists of fragments of prompt strings to which different
|
||||
weights shall be applied.
|
||||
:param kwargs: If the keyword arg "fragment_weights" is passed, it shall contain a batch of lists of weights
|
||||
for the prompt fragments. In this case text must contain batches of lists of prompt fragments.
|
||||
:return: A tensor of shape (B, 77, 768) containing weighted embeddings
|
||||
'''
|
||||
if self.fragment_weights_key not in kwargs:
|
||||
# fallback to base class implementation
|
||||
return super().forward(text, **kwargs)
|
||||
|
||||
fragment_weights = kwargs[self.fragment_weights_key]
|
||||
# self.transformer doesn't like receiving "fragment_weights" as an argument
|
||||
kwargs.pop(self.fragment_weights_key)
|
||||
|
||||
should_return_tokens = False
|
||||
if self.return_tokens_key in kwargs:
|
||||
should_return_tokens = kwargs.get(self.return_tokens_key, False)
|
||||
# self.transformer doesn't like having extra kwargs
|
||||
kwargs.pop(self.return_tokens_key)
|
||||
|
||||
batch_z = None
|
||||
batch_tokens = None
|
||||
for fragments, weights in zip(text, fragment_weights):
|
||||
|
||||
# First, weight tokens in individual fragments by scaling the feature vectors as requested (effectively
|
||||
# applying a multiplier to the CFG scale on a per-token basis).
|
||||
# For tokens weighted<1, intuitively we want SD to become not merely *less* interested in the concept
|
||||
# captured by the fragment but actually *dis*interested in it (a 0.01 interest in "red" is still an active
|
||||
# interest, however small, in redness; what the user probably intends when they attach the number 0.01 to
|
||||
# "red" is to tell SD that it should almost completely *ignore* redness).
|
||||
# To do this, the embedding is lerped away from base_embedding in the direction of an embedding for a prompt
|
||||
# string from which the low-weighted fragment has been simply removed. The closer the weight is to zero, the
|
||||
# closer the resulting embedding is to an embedding for a prompt that simply lacks this fragment.
|
||||
|
||||
# handle weights >=1
|
||||
tokens, per_token_weights = self.get_tokens_and_weights(fragments, weights)
|
||||
base_embedding = self.build_weighted_embedding_tensor(tokens, per_token_weights, **kwargs)
|
||||
|
||||
# this is our starting point
|
||||
embeddings = base_embedding.unsqueeze(0)
|
||||
per_embedding_weights = [1.0]
|
||||
|
||||
# now handle weights <1
|
||||
# Do this by building extra embeddings tensors that lack the words being <1 weighted. These will be lerped
|
||||
# with the embeddings tensors that have the words, such that if the weight of a word is 0.5, the resulting
|
||||
# embedding will be exactly half-way between the unweighted prompt and the prompt with the <1 weighted words
|
||||
# removed.
|
||||
# eg for "mountain:1 man:0.5", intuitively the "man" should be "half-gone". therefore, append an embedding
|
||||
# for "mountain" (i.e. without "man") to the already-produced embedding for "mountain man", and weight it
|
||||
# such that the resulting lerped embedding is exactly half-way between "mountain man" and "mountain".
|
||||
for index, fragment_weight in enumerate(weights):
|
||||
if fragment_weight < 1:
|
||||
fragments_without_this = fragments[:index] + fragments[index+1:]
|
||||
weights_without_this = weights[:index] + weights[index+1:]
|
||||
tokens, per_token_weights = self.get_tokens_and_weights(fragments_without_this, weights_without_this)
|
||||
embedding_without_this = self.build_weighted_embedding_tensor(tokens, per_token_weights, **kwargs)
|
||||
|
||||
embeddings = torch.cat((embeddings, embedding_without_this.unsqueeze(0)), dim=1)
|
||||
# weight of the embedding *without* this fragment gets *stronger* as its weight approaches 0
|
||||
# if fragment_weight = 0, basically we want embedding_without_this to completely overwhelm base_embedding
|
||||
# therefore:
|
||||
# fragment_weight = 1: we are at base_z => lerp weight 0
|
||||
# fragment_weight = 0.5: we are halfway between base_z and here => lerp weight 1
|
||||
# fragment_weight = 0: we're now entirely overriding base_z ==> lerp weight inf
|
||||
# so let's use tan(), because:
|
||||
# tan is 0.0 at 0,
|
||||
# 1.0 at PI/4, and
|
||||
# inf at PI/2
|
||||
# -> tan((1-weight)*PI/2) should give us ideal lerp weights
|
||||
epsilon = 1e-9
|
||||
fragment_weight = max(epsilon, fragment_weight) # inf is bad
|
||||
embedding_lerp_weight = math.tan((1.0 - fragment_weight) * math.pi / 2)
|
||||
# todo handle negative weight?
|
||||
|
||||
per_embedding_weights.append(embedding_lerp_weight)
|
||||
|
||||
lerped_embeddings = self.apply_embedding_weights(embeddings, per_embedding_weights, normalize=True).squeeze(0)
|
||||
|
||||
#print(f"assembled tokens for '{fragments}' into tensor of shape {lerped_embeddings.shape}")
|
||||
|
||||
# append to batch
|
||||
batch_z = lerped_embeddings.unsqueeze(0) if batch_z is None else torch.cat([batch_z, lerped_embeddings.unsqueeze(0)], dim=1)
|
||||
batch_tokens = tokens.unsqueeze(0) if batch_tokens is None else torch.cat([batch_tokens, tokens.unsqueeze(0)], dim=1)
|
||||
|
||||
# should have shape (B, 77, 768)
|
||||
#print(f"assembled all tokens into tensor of shape {batch_z.shape}")
|
||||
|
||||
if should_return_tokens:
|
||||
return batch_z, batch_tokens
|
||||
else:
|
||||
return batch_z
|
||||
|
||||
@classmethod
|
||||
def apply_embedding_weights(self, embeddings: torch.Tensor, per_embedding_weights: list[float], normalize:bool) -> torch.Tensor:
|
||||
per_embedding_weights = torch.tensor(per_embedding_weights, dtype=embeddings.dtype, device=embeddings.device)
|
||||
if normalize:
|
||||
per_embedding_weights = per_embedding_weights / torch.sum(per_embedding_weights)
|
||||
reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1, 1,))
|
||||
#reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1,1,)).expand(embeddings.shape)
|
||||
return torch.sum(embeddings * reshaped_weights, dim=1)
|
||||
# lerped embeddings has shape (77, 768)
|
||||
|
||||
|
||||
def get_tokens_and_weights(self, fragments: list[str], weights: list[float]) -> (torch.Tensor, torch.Tensor):
|
||||
'''
|
||||
|
||||
:param fragments:
|
||||
:param weights: Per-fragment weights (CFG scaling). No need for these to be normalized. They will not be normalized here and that's fine.
|
||||
:return:
|
||||
'''
|
||||
# empty is meaningful
|
||||
if len(fragments) == 0 and len(weights) == 0:
|
||||
fragments = ['']
|
||||
weights = [1]
|
||||
item_encodings = self.tokenizer(
|
||||
fragments,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
return_overflowing_tokens=False,
|
||||
padding='do_not_pad',
|
||||
return_tensors=None, # just give me a list of ints
|
||||
)['input_ids']
|
||||
all_tokens = []
|
||||
per_token_weights = []
|
||||
#print("all fragments:", fragments, weights)
|
||||
for index, fragment in enumerate(item_encodings):
|
||||
weight = weights[index]
|
||||
#print("processing fragment", fragment, weight)
|
||||
fragment_tokens = item_encodings[index]
|
||||
#print("fragment", fragment, "processed to", fragment_tokens)
|
||||
# trim bos and eos markers before appending
|
||||
all_tokens.extend(fragment_tokens[1:-1])
|
||||
per_token_weights.extend([weight] * (len(fragment_tokens) - 2))
|
||||
|
||||
if (len(all_tokens) + 2) > self.max_length:
|
||||
excess_token_count = (len(all_tokens) + 2) - self.max_length
|
||||
print(f"prompt is {excess_token_count} token(s) too long and has been truncated")
|
||||
all_tokens = all_tokens[:self.max_length - 2]
|
||||
|
||||
# pad out to a 77-entry array: [eos_token, <prompt tokens>, eos_token, ..., eos_token]
|
||||
# (77 = self.max_length)
|
||||
pad_length = self.max_length - 1 - len(all_tokens)
|
||||
all_tokens.insert(0, self.tokenizer.bos_token_id)
|
||||
all_tokens.extend([self.tokenizer.eos_token_id] * pad_length)
|
||||
per_token_weights.insert(0, 1)
|
||||
per_token_weights.extend([1] * pad_length)
|
||||
|
||||
all_tokens_tensor = torch.tensor(all_tokens, dtype=torch.long).to(self.device)
|
||||
per_token_weights_tensor = torch.tensor(per_token_weights, dtype=torch.float32).to(self.device)
|
||||
#print(f"assembled all_tokens_tensor with shape {all_tokens_tensor.shape}")
|
||||
return all_tokens_tensor, per_token_weights_tensor
|
||||
|
||||
def build_weighted_embedding_tensor(self, tokens: torch.Tensor, per_token_weights: torch.Tensor, weight_delta_from_empty=True, **kwargs) -> torch.Tensor:
|
||||
'''
|
||||
Build a tensor representing the passed-in tokens, each of which has a weight.
|
||||
:param tokens: A tensor of shape (77) containing token ids (integers)
|
||||
:param per_token_weights: A tensor of shape (77) containing weights (floats)
|
||||
:param method: Whether to multiply the whole feature vector for each token or just its distance from an "empty" feature vector
|
||||
:param kwargs: passed on to self.transformer()
|
||||
:return: A tensor of shape (1, 77, 768) representing the requested weighted embeddings.
|
||||
'''
|
||||
#print(f"building weighted embedding tensor for {tokens} with weights {per_token_weights}")
|
||||
z = self.transformer(input_ids=tokens.unsqueeze(0), **kwargs)
|
||||
batch_weights_expanded = per_token_weights.reshape(per_token_weights.shape + (1,)).expand(z.shape)
|
||||
|
||||
if weight_delta_from_empty:
|
||||
empty_tokens = self.tokenizer([''] * z.shape[0],
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
padding='max_length',
|
||||
return_tensors='pt'
|
||||
)['input_ids'].to(self.device)
|
||||
empty_z = self.transformer(input_ids=empty_tokens, **kwargs)
|
||||
z_delta_from_empty = z - empty_z
|
||||
weighted_z = empty_z + (z_delta_from_empty * batch_weights_expanded)
|
||||
|
||||
weighted_z_delta_from_empty = (weighted_z-empty_z)
|
||||
#print("weighted z has delta from empty with sum", weighted_z_delta_from_empty.sum().item(), "mean", weighted_z_delta_from_empty.mean().item() )
|
||||
|
||||
#print("using empty-delta method, first 5 rows:")
|
||||
#print(weighted_z[:5])
|
||||
|
||||
return weighted_z
|
||||
|
||||
else:
|
||||
original_mean = z.mean()
|
||||
z *= batch_weights_expanded
|
||||
after_weighting_mean = z.mean()
|
||||
# correct the mean. not sure if this is right but it's what the automatic1111 fork of SD does
|
||||
mean_correction_factor = original_mean/after_weighting_mean
|
||||
z *= mean_correction_factor
|
||||
return z
|
||||
|
||||
|
||||
class FrozenCLIPTextEmbedder(nn.Module):
|
||||
|
Loading…
Reference in New Issue
Block a user