diff --git a/backend/server.py b/backend/server.py
index 7b8a8a5a69..f14c141e12 100644
--- a/backend/server.py
+++ b/backend/server.py
@@ -527,7 +527,7 @@ def parameters_to_generated_image_metadata(parameters):
         rfc_dict["sampler"] = parameters["sampler_name"]
 
     # display weighted subprompts (liable to change)
-    subprompts = split_weighted_subprompts(parameters["prompt"])
+    subprompts = split_weighted_subprompts(parameters["prompt"], skip_normalize=True)
     subprompts = [{"prompt": x[0], "weight": x[1]} for x in subprompts]
     rfc_dict["prompt"] = subprompts
diff --git a/configs/stable-diffusion/v1-inference.yaml b/configs/stable-diffusion/v1-inference.yaml
index 9c773077b6..baf91f6e26 100644
--- a/configs/stable-diffusion/v1-inference.yaml
+++ b/configs/stable-diffusion/v1-inference.yaml
@@ -76,4 +76,4 @@ model:
           target: torch.nn.Identity
 
     cond_stage_config:
-      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
+      target: ldm.modules.encoders.modules.WeightedFrozenCLIPEmbedder
diff --git a/ldm/invoke/conditioning.py b/ldm/invoke/conditioning.py
index e3685db615..189af10c24 100644
--- a/ldm/invoke/conditioning.py
+++ b/ldm/invoke/conditioning.py
@@ -97,7 +97,8 @@ def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: Fl
     if type(flattened_prompt) is not FlattenedPrompt:
         raise f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead"
     fragments = [x.text for x in flattened_prompt.children]
-    embeddings, tokens = model.get_learned_conditioning([' '.join(fragments)], return_tokens=True)
+    weights = [x.weight for x in flattened_prompt.children]
+    embeddings, tokens = model.get_learned_conditioning([fragments], return_tokens=True, fragment_weights=[weights])
     return embeddings, tokens
diff --git a/ldm/invoke/prompt_parser.py b/ldm/invoke/prompt_parser.py
index ef6a72a49b..1d5fb3c04a 100644
--- a/ldm/invoke/prompt_parser.py
+++ b/ldm/invoke/prompt_parser.py
@@ -329,7 +329,7 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float)
 
     def parse_fragment_str(x, in_quotes: bool=False, in_parens: bool=False):
         fragment_string = x[0]
-        print(f"parsing fragment string \"{fragment_string}\"")
+        #print(f"parsing fragment string \"{fragment_string}\"")
         if len(fragment_string.strip()) == 0:
             return Fragment('')
diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py
index 8c858757eb..44e418acb1 100644
--- a/ldm/models/diffusion/ksampler.py
+++ b/ldm/models/diffusion/ksampler.py
@@ -183,7 +183,6 @@ class KSampler(Sampler):
         )
 
         # sigmas are set up in make_schedule - we take the last steps items
-        total_steps = len(self.sigmas)
         sigmas = self.sigmas[-S-1:]
 
         # x_T is variation noise. When an init image is provided (in x0) we need to add
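Note: the sketch below is not part of the patch. It mirrors the new lines in the conditioning.py hunk above, showing how a flattened prompt is now split into parallel fragment/weight lists before calling `model.get_learned_conditioning([fragments], return_tokens=True, fragment_weights=[weights])`. The `FragmentStub`/`FlattenedPromptStub` classes are stand-ins for the real parser types, and the call itself assumes the rest of the branch forwards `fragment_weights` through to the cond stage model (the new WeightedFrozenCLIPEmbedder).

```python
# Illustrative stand-ins only; the real Fragment / FlattenedPrompt classes live
# in ldm/invoke/prompt_parser.py.
from dataclasses import dataclass

@dataclass
class FragmentStub:
    text: str
    weight: float = 1.0

@dataclass
class FlattenedPromptStub:
    children: list

prompt = FlattenedPromptStub(children=[FragmentStub("a man walking up a mountain", 1.0),
                                       FragmentStub("red jacket", 0.5)])

# the two new lines in the hunk above do exactly this:
fragments = [x.text for x in prompt.children]    # ['a man walking up a mountain', 'red jacket']
weights = [x.weight for x in prompt.children]    # [1.0, 0.5]
print(fragments, weights)

# the call then becomes (model assumed to be a loaded LatentDiffusion instance):
# embeddings, tokens = model.get_learned_conditioning(
#     [fragments], return_tokens=True, fragment_weights=[weights])
```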
diff --git a/ldm/models/diffusion/sampler.py b/ldm/models/diffusion/sampler.py
index 879b85495a..e33d57fe31 100644
--- a/ldm/models/diffusion/sampler.py
+++ b/ldm/models/diffusion/sampler.py
@@ -4,7 +4,7 @@ ldm.models.diffusion.sampler
 Base class for ldm.models.diffusion.ddim, ldm.models.diffusion.ksampler, etc
 '''
 
-from enum import Enum
+from math import ceil
 import torch
 import numpy as np
diff --git a/ldm/models/diffusion/shared_invokeai_diffusion.py b/ldm/models/diffusion/shared_invokeai_diffusion.py
index d4f059de09..0a613091d5 100644
--- a/ldm/models/diffusion/shared_invokeai_diffusion.py
+++ b/ldm/models/diffusion/shared_invokeai_diffusion.py
@@ -1,4 +1,5 @@
 from enum import Enum
+from math import ceil
 from typing import Callable
 
 import torch
@@ -104,6 +105,58 @@ class InvokeAIDiffuserComponent:
 
         return combined_next_x
 
+
+    # todo: make this work
+    @classmethod
+    def apply_conjunction(cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale):
+        x_in = torch.cat([x] * 2)
+        t_in = torch.cat([t] * 2) # aka sigmas
+
+        deltas = None
+        uncond_latents = None
+        weighted_cond_list = c_or_weighted_c_list if type(c_or_weighted_c_list) is list else [(c_or_weighted_c_list, 1)]
+
+        # below is fugly omg
+        num_actual_conditionings = len(c_or_weighted_c_list)
+        conditionings = [uc] + [c for c,weight in weighted_cond_list]
+        weights = [1] + [weight for c,weight in weighted_cond_list]
+        chunk_count = ceil(len(conditionings)/2)
+        deltas = None
+        for chunk_index in range(chunk_count):
+            offset = chunk_index*2
+            chunk_size = min(2, len(conditionings)-offset)
+
+            if chunk_size == 1:
+                c_in = conditionings[offset]
+                latents_a = forward_func(x_in[:-1], t_in[:-1], c_in)
+                latents_b = None
+            else:
+                c_in = torch.cat(conditionings[offset:offset+2])
+                latents_a, latents_b = forward_func(x_in, t_in, c_in).chunk(2)
+
+            # first chunk is guaranteed to be 2 entries: uncond_latents + first conditioning
+            if chunk_index == 0:
+                uncond_latents = latents_a
+                deltas = latents_b - uncond_latents
+            else:
+                deltas = torch.cat((deltas, latents_a - uncond_latents))
+                if latents_b is not None:
+                    deltas = torch.cat((deltas, latents_b - uncond_latents))
+
+        # merge the weighted deltas together into a single merged delta
+        per_delta_weights = torch.tensor(weights[1:], dtype=deltas.dtype, device=deltas.device)
+        normalize = False
+        if normalize:
+            per_delta_weights /= torch.sum(per_delta_weights)
+        reshaped_weights = per_delta_weights.reshape(per_delta_weights.shape + (1, 1, 1))
+        deltas_merged = torch.sum(deltas * reshaped_weights, dim=0, keepdim=True)
+
+        # old_return_value = super().forward(x, sigma, uncond, cond, cond_scale)
+        # assert(0 == len(torch.nonzero(old_return_value - (uncond_latents + deltas_merged * cond_scale))))
+
+        return uncond_latents + deltas_merged * global_guidance_scale
+
+
 # adapted from bloc97's CrossAttentionControl colab
 # https://github.com/bloc97/CrossAttentionControl
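Note: the sketch below is not part of the patch. It isolates the merge that `apply_conjunction` performs once the unconditioned latents and each sub-prompt's conditioned latents have been predicted: the result is `uncond + sum_i(w_i * (cond_i - uncond)) * scale`. Shapes and values are made up, and the weights are deliberately left unnormalized to match `normalize = False` above.

```python
import torch

uncond = torch.randn(1, 4, 64, 64)                     # unconditioned prediction
conds = [torch.randn(1, 4, 64, 64) for _ in range(2)]  # one prediction per sub-prompt
weights = [1.0, 0.7]                                   # per-sub-prompt weights
guidance_scale = 7.5

deltas = torch.cat([c - uncond for c in conds])        # (2, 4, 64, 64)
w = torch.tensor(weights).reshape(-1, 1, 1, 1)
merged_delta = torch.sum(deltas * w, dim=0, keepdim=True)

result = uncond + merged_delta * guidance_scale        # same form as the return above
print(result.shape)                                    # torch.Size([1, 4, 64, 64])
```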
diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py
index 18878af443..2883f24d1a 100644
--- a/ldm/modules/encoders/modules.py
+++ b/ldm/modules/encoders/modules.py
@@ -440,12 +440,6 @@ class FrozenCLIPEmbedder(AbstractEncoder):
 
     def forward(self, text, **kwargs):
-        should_return_tokens = False
-        if 'return_tokens' in kwargs:
-            should_return_tokens = kwargs.get('return_tokens', False)
-            # self.transformer doesn't like having extra kwargs
-            kwargs.pop('return_tokens')
-
         batch_encoding = self.tokenizer(
             text,
             truncation=True,
@@ -458,22 +452,211 @@ class FrozenCLIPEmbedder(AbstractEncoder):
         tokens = batch_encoding['input_ids'].to(self.device)
         z = self.transformer(input_ids=tokens, **kwargs)
-        if should_return_tokens:
-            return z, tokens
-        else:
-            return z
+        return z
 
     def encode(self, text, **kwargs):
         return self(text, **kwargs)
 
 
 class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
+
+    fragment_weights_key = "fragment_weights"
+    return_tokens_key = "return_tokens"
+
+    def forward(self, text: list, **kwargs):
+        '''
+
+        :param text: A batch of prompt strings, or, a batch of lists of fragments of prompt strings to which different
+        weights shall be applied.
+        :param kwargs: If the keyword arg "fragment_weights" is passed, it shall contain a batch of lists of weights
+        for the prompt fragments. In this case text must contain batches of lists of prompt fragments.
+        :return: A tensor of shape (B, 77, 768) containing weighted embeddings
+        '''
+        if self.fragment_weights_key not in kwargs:
+            # fallback to base class implementation
+            return super().forward(text, **kwargs)
+
+        fragment_weights = kwargs[self.fragment_weights_key]
+        # self.transformer doesn't like receiving "fragment_weights" as an argument
+        kwargs.pop(self.fragment_weights_key)
+
+        should_return_tokens = False
+        if self.return_tokens_key in kwargs:
+            should_return_tokens = kwargs.get(self.return_tokens_key, False)
+            # self.transformer doesn't like having extra kwargs
+            kwargs.pop(self.return_tokens_key)
+
+        batch_z = None
+        batch_tokens = None
+        for fragments, weights in zip(text, fragment_weights):
+
+            # First, weight tokens in individual fragments by scaling the feature vectors as requested (effectively
+            # applying a multiplier to the CFG scale on a per-token basis).
+            # For tokens weighted<1, intuitively we want SD to become not merely *less* interested in the concept
+            # captured by the fragment but actually *dis*interested in it (a 0.01 interest in "red" is still an active
+            # interest, however small, in redness; what the user probably intends when they attach the number 0.01 to
+            # "red" is to tell SD that it should almost completely *ignore* redness).
+            # To do this, the embedding is lerped away from base_embedding in the direction of an embedding for a prompt
+            # string from which the low-weighted fragment has been simply removed. The closer the weight is to zero, the
+            # closer the resulting embedding is to an embedding for a prompt that simply lacks this fragment.
+
+            # handle weights >=1
+            tokens, per_token_weights = self.get_tokens_and_weights(fragments, weights)
+            base_embedding = self.build_weighted_embedding_tensor(tokens, per_token_weights, **kwargs)
+
+            # this is our starting point
+            embeddings = base_embedding.unsqueeze(0)
+            per_embedding_weights = [1.0]
+
+            # now handle weights <1
+            # Do this by building extra embeddings tensors that lack the words being <1 weighted. These will be lerped
+            # with the embeddings tensors that have the words, such that if the weight of a word is 0.5, the resulting
+            # embedding will be exactly half-way between the unweighted prompt and the prompt with the <1 weighted words
+            # removed.
+            # eg for "mountain:1 man:0.5", intuitively the "man" should be "half-gone". therefore, append an embedding
+            # for "mountain" (i.e. without "man") to the already-produced embedding for "mountain man", and weight it
+            # such that the resulting lerped embedding is exactly half-way between "mountain man" and "mountain".
+            for index, fragment_weight in enumerate(weights):
+                if fragment_weight < 1:
+                    fragments_without_this = fragments[:index] + fragments[index+1:]
+                    weights_without_this = weights[:index] + weights[index+1:]
+                    tokens, per_token_weights = self.get_tokens_and_weights(fragments_without_this, weights_without_this)
+                    embedding_without_this = self.build_weighted_embedding_tensor(tokens, per_token_weights, **kwargs)
+
+                    embeddings = torch.cat((embeddings, embedding_without_this.unsqueeze(0)), dim=1)
+                    # weight of the embedding *without* this fragment gets *stronger* as its weight approaches 0
+                    # if fragment_weight = 0, basically we want embedding_without_this to completely overwhelm base_embedding
+                    # therefore:
+                    # fragment_weight = 1: we are at base_z => lerp weight 0
+                    # fragment_weight = 0.5: we are halfway between base_z and here => lerp weight 1
+                    # fragment_weight = 0: we're now entirely overriding base_z ==> lerp weight inf
+                    # so let's use tan(), because:
+                    # tan is 0.0 at 0,
+                    # 1.0 at PI/4, and
+                    # inf at PI/2
+                    # -> tan((1-weight)*PI/2) should give us ideal lerp weights
+                    epsilon = 1e-9
+                    fragment_weight = max(epsilon, fragment_weight) # inf is bad
+                    embedding_lerp_weight = math.tan((1.0 - fragment_weight) * math.pi / 2)
+                    # todo handle negative weight?
+
+                    per_embedding_weights.append(embedding_lerp_weight)
+
+            lerped_embeddings = self.apply_embedding_weights(embeddings, per_embedding_weights, normalize=True).squeeze(0)
+
+            #print(f"assembled tokens for '{fragments}' into tensor of shape {lerped_embeddings.shape}")
+
+            # append to batch
+            batch_z = lerped_embeddings.unsqueeze(0) if batch_z is None else torch.cat([batch_z, lerped_embeddings.unsqueeze(0)], dim=1)
+            batch_tokens = tokens.unsqueeze(0) if batch_tokens is None else torch.cat([batch_tokens, tokens.unsqueeze(0)], dim=1)
+
+        # should have shape (B, 77, 768)
+        #print(f"assembled all tokens into tensor of shape {batch_z.shape}")
+
+        if should_return_tokens:
+            return batch_z, batch_tokens
+        else:
+            return batch_z
+
     @classmethod
     def apply_embedding_weights(self, embeddings: torch.Tensor, per_embedding_weights: list[float], normalize:bool) -> torch.Tensor:
         per_embedding_weights = torch.tensor(per_embedding_weights, dtype=embeddings.dtype, device=embeddings.device)
         if normalize:
             per_embedding_weights = per_embedding_weights / torch.sum(per_embedding_weights)
         reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1, 1,))
+        #reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1,1,)).expand(embeddings.shape)
         return torch.sum(embeddings * reshaped_weights, dim=1)
+        # lerped embeddings has shape (77, 768)
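Note: the sketch below is not part of the patch. It checks the tan()-based lerp weighting described in the comments above against the normalized blend that `apply_embedding_weights(..., normalize=True)` computes: for a fragment weight of 0.5, the "without" embedding gets lerp weight tan(0.25*pi) = 1.0, so after normalization the base and "without" embeddings each contribute 0.5, landing exactly half-way as intended. The tensors are constant stand-ins for real CLIP embeddings.

```python
import math
import torch

def lerp_weight(fragment_weight: float, epsilon: float = 1e-9) -> float:
    # mirrors the tan((1 - weight) * pi / 2) trick from the loop above
    fragment_weight = max(epsilon, fragment_weight)
    return math.tan((1.0 - fragment_weight) * math.pi / 2)

base = torch.full((1, 77, 768), 2.0)      # stand-in embedding for "mountain man"
without = torch.full((1, 77, 768), 0.0)   # stand-in embedding for "mountain" only

weights = torch.tensor([1.0, lerp_weight(0.5)])   # [1.0, ~1.0]
weights = weights / weights.sum()                 # normalize=True
stacked = torch.stack([base, without], dim=1)     # (1, 2, 77, 768)
blended = (stacked * weights.reshape(1, -1, 1, 1)).sum(dim=1)

assert torch.allclose(blended, torch.full((1, 77, 768), 1.0))  # exactly half-way
```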
+
+
+    def get_tokens_and_weights(self, fragments: list[str], weights: list[float]) -> (torch.Tensor, torch.Tensor):
+        '''
+
+        :param fragments:
+        :param weights: Per-fragment weights (CFG scaling). No need for these to be normalized. They will not be normalized here and that's fine.
+        :return:
+        '''
+        # empty is meaningful
+        if len(fragments) == 0 and len(weights) == 0:
+            fragments = ['']
+            weights = [1]
+        item_encodings = self.tokenizer(
+            fragments,
+            truncation=True,
+            max_length=self.max_length,
+            return_overflowing_tokens=False,
+            padding='do_not_pad',
+            return_tensors=None, # just give me a list of ints
+        )['input_ids']
+        all_tokens = []
+        per_token_weights = []
+        #print("all fragments:", fragments, weights)
+        for index, fragment in enumerate(item_encodings):
+            weight = weights[index]
+            #print("processing fragment", fragment, weight)
+            fragment_tokens = item_encodings[index]
+            #print("fragment", fragment, "processed to", fragment_tokens)
+            # trim bos and eos markers before appending
+            all_tokens.extend(fragment_tokens[1:-1])
+            per_token_weights.extend([weight] * (len(fragment_tokens) - 2))
+
+        if (len(all_tokens) + 2) > self.max_length:
+            excess_token_count = (len(all_tokens) + 2) - self.max_length
+            print(f"prompt is {excess_token_count} token(s) too long and has been truncated")
+            all_tokens = all_tokens[:self.max_length - 2]
+
+        # pad out to a 77-entry array: [bos_token, <prompt tokens>, eos_token, ..., eos_token]
+        # (77 = self.max_length)
+        pad_length = self.max_length - 1 - len(all_tokens)
+        all_tokens.insert(0, self.tokenizer.bos_token_id)
+        all_tokens.extend([self.tokenizer.eos_token_id] * pad_length)
+        per_token_weights.insert(0, 1)
+        per_token_weights.extend([1] * pad_length)
+
+        all_tokens_tensor = torch.tensor(all_tokens, dtype=torch.long).to(self.device)
+        per_token_weights_tensor = torch.tensor(per_token_weights, dtype=torch.float32).to(self.device)
+        #print(f"assembled all_tokens_tensor with shape {all_tokens_tensor.shape}")
+        return all_tokens_tensor, per_token_weights_tensor
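Note: the toy below is not part of the patch. It walks through the token/weight assembly that `get_tokens_and_weights` performs: per-fragment token ids are concatenated with their bos/eos markers stripped, each fragment's weight is repeated once per remaining token, and the result is wrapped in bos/eos and padded to `max_length`. The tokenizer output is faked so the example stays self-contained; 49406/49407 are CLIP's bos/eos ids.

```python
BOS, EOS, MAX_LEN = 49406, 49407, 77

fake_encodings = [[BOS, 1001, 1002, EOS],   # "mountain" -> 2 content tokens
                  [BOS, 2001, EOS]]         # "man"      -> 1 content token
weights = [1.0, 0.5]

all_tokens, per_token_weights = [], []
for fragment_tokens, weight in zip(fake_encodings, weights):
    all_tokens.extend(fragment_tokens[1:-1])                      # strip bos/eos
    per_token_weights.extend([weight] * (len(fragment_tokens) - 2))

pad_length = MAX_LEN - 1 - len(all_tokens)
all_tokens = [BOS] + all_tokens + [EOS] * pad_length
per_token_weights = [1.0] + per_token_weights + [1.0] * pad_length

assert len(all_tokens) == len(per_token_weights) == MAX_LEN
print(all_tokens[:5], per_token_weights[:5])
# [49406, 1001, 1002, 2001, 49407] [1.0, 1.0, 1.0, 0.5, 1.0]
```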
+
+    def build_weighted_embedding_tensor(self, tokens: torch.Tensor, per_token_weights: torch.Tensor, weight_delta_from_empty=True, **kwargs) -> torch.Tensor:
+        '''
+        Build a tensor representing the passed-in tokens, each of which has a weight.
+        :param tokens: A tensor of shape (77) containing token ids (integers)
+        :param per_token_weights: A tensor of shape (77) containing weights (floats)
+        :param weight_delta_from_empty: If True, weight each token's distance from an "empty" prompt's feature vector rather than the whole feature vector
+        :param kwargs: passed on to self.transformer()
+        :return: A tensor of shape (1, 77, 768) representing the requested weighted embeddings.
+        '''
+        #print(f"building weighted embedding tensor for {tokens} with weights {per_token_weights}")
+        z = self.transformer(input_ids=tokens.unsqueeze(0), **kwargs)
+        batch_weights_expanded = per_token_weights.reshape(per_token_weights.shape + (1,)).expand(z.shape)
+
+        if weight_delta_from_empty:
+            empty_tokens = self.tokenizer([''] * z.shape[0],
+                                          truncation=True,
+                                          max_length=self.max_length,
+                                          padding='max_length',
+                                          return_tensors='pt'
+                                          )['input_ids'].to(self.device)
+            empty_z = self.transformer(input_ids=empty_tokens, **kwargs)
+            z_delta_from_empty = z - empty_z
+            weighted_z = empty_z + (z_delta_from_empty * batch_weights_expanded)
+
+            weighted_z_delta_from_empty = (weighted_z-empty_z)
+            #print("weighted z has delta from empty with sum", weighted_z_delta_from_empty.sum().item(), "mean", weighted_z_delta_from_empty.mean().item() )
+
+            #print("using empty-delta method, first 5 rows:")
+            #print(weighted_z[:5])
+
+            return weighted_z
+
+        else:
+            original_mean = z.mean()
+            z *= batch_weights_expanded
+            after_weighting_mean = z.mean()
+            # correct the mean. not sure if this is right but it's what the automatic1111 fork of SD does
+            mean_correction_factor = original_mean/after_weighting_mean
+            z *= mean_correction_factor
+            return z
 
 
 class FrozenCLIPTextEmbedder(nn.Module):
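Note: the sketch below is not part of the patch. It isolates the `weight_delta_from_empty` arithmetic used by `build_weighted_embedding_tensor`, with random tensors standing in for real transformer outputs: a per-token weight of 1.0 leaves that token's feature vector untouched, while a weight of 0.0 collapses it onto the embedding of the empty prompt.

```python
import torch

z = torch.randn(1, 77, 768)        # embedding of the actual prompt
empty_z = torch.randn(1, 77, 768)  # embedding of the empty prompt ''

per_token_weights = torch.ones(77)
per_token_weights[5] = 0.0         # pretend token 5 is weighted to zero
w = per_token_weights.reshape(77, 1).expand(z.shape)

weighted_z = empty_z + (z - empty_z) * w

assert torch.allclose(weighted_z[0, 4], z[0, 4])        # weight 1.0: unchanged
assert torch.allclose(weighted_z[0, 5], empty_z[0, 5])  # weight 0.0: becomes "empty"
```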