wip implementing options in diffuse step

Damian at mba 2022-10-22 12:15:34 +02:00
parent ee7d4d712a
commit 8273c04575
2 changed files with 50 additions and 3 deletions

View File

@@ -16,7 +16,8 @@ from typing import Union
import torch
from .prompt_parser import PromptParser, Blend, FlattenedPrompt, \
    CrossAttentionControlledFragment, CrossAttentionControlSubstitute, CrossAttentionControlAppend
    CrossAttentionControlledFragment, CrossAttentionControlSubstitute, CrossAttentionControlAppend, Fragment
from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder
@@ -65,27 +66,54 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
    if wants_cross_attention_control:
        original_prompt = FlattenedPrompt()
        edited_prompt = FlattenedPrompt()
        # for name, a0, a1, b0, b1 in edit_opcodes: only name == 'equal' is currently parsed
        original_token_count = 0
        edited_token_count = 0
        edit_opcodes = []
        edit_options = []
        for fragment in flattened_prompt.children:
            if type(fragment) is CrossAttentionControlSubstitute:
                original_prompt.append(fragment.original)
                edited_prompt.append(fragment.edited)
                to_replace_token_count = get_tokens_length(model, fragment.original)
                replacement_token_count = get_tokens_length(model, fragment.edited)
                edit_opcodes.append(('replace',
                                     original_token_count, original_token_count + to_replace_token_count,
                                     edited_token_count, edited_token_count + replacement_token_count
                                     ))
                original_token_count += to_replace_token_count
                edited_token_count += replacement_token_count
                edit_options.append(fragment.options)
            #elif type(fragment) is CrossAttentionControlAppend:
            #    edited_prompt.append(fragment.fragment)
            else:
                # regular fragment
                original_prompt.append(fragment)
                edited_prompt.append(fragment)
                count = get_tokens_length(model, [fragment])
                edit_opcodes.append(('equal', original_token_count, original_token_count+count, edited_token_count, edited_token_count+count))
                edit_options.append(None)
                original_token_count += count
                edited_token_count += count
        original_embeddings, original_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, original_prompt)
        edited_embeddings, edited_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, edited_prompt)
        conditioning = original_embeddings
        edited_conditioning = edited_embeddings
        edit_opcodes = build_token_edit_opcodes(original_tokens, edited_tokens)
        print('got edit_opcodes', edit_opcodes, 'options', edit_options)
    else:
        conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt)
    unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, parsed_negative_prompt)
    return (unconditioning, conditioning, edited_conditioning, edit_opcodes)
    return (
        unconditioning, conditioning, edited_conditioning, edit_opcodes
        #InvokeAIDiffuserComponent.ExtraConditioningInfo(edited_conditioning=edited_conditioning,
        #                                                edit_opcodes=edit_opcodes,
        #                                                edit_options=edit_options)
    )
def build_token_edit_opcodes(original_tokens, edited_tokens):
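For orientation, a minimal standalone sketch of the bookkeeping the new loop above performs; the fragment list, token counts, and the options dict are invented, and plain tuples stand in for the parser's Fragment / CrossAttentionControlSubstitute objects:

# Sketch only, not repo code: same 'equal'/'replace' opcode accumulation as above.
edit_opcodes = []        # (op, orig_start, orig_end, edited_start, edited_end)
edit_options = []        # per-span options; None for unchanged spans
original_token_count = 0
edited_token_count = 0

# (replacement text or None, options dict or None, original token count, edited token count)
fragments = [
    (None, None, 3, 3),                    # regular fragment -> 'equal'
    ("a tiger", {"s_start": 0.1}, 2, 3),   # substitution -> 'replace'; option name is illustrative
    (None, None, 3, 3),                    # regular fragment -> 'equal'
]

for replacement, options, orig_len, edited_len in fragments:
    op = 'equal' if replacement is None else 'replace'
    edit_opcodes.append((op,
                         original_token_count, original_token_count + orig_len,
                         edited_token_count, edited_token_count + edited_len))
    edit_options.append(options)
    original_token_count += orig_len
    edited_token_count += edited_len

print(edit_opcodes)   # [('equal', 0, 3, 0, 3), ('replace', 3, 5, 3, 6), ('equal', 5, 8, 6, 9)]
print(edit_options)   # [None, {'s_start': 0.1}, None]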
@@ -102,6 +130,10 @@ def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: Fl
    embeddings, tokens = model.get_learned_conditioning([fragments], return_tokens=True, fragment_weights=[weights])
    return embeddings, tokens
def get_tokens_length(model, fragments: list[Fragment]):
    fragment_texts = [x.text for x in fragments]
    tokens = model.cond_stage_model.get_tokens(fragment_texts, include_start_and_end_markers=False)
    return sum([len(x) for x in tokens])
def split_weighted_subprompts(text, skip_normalize=False)->list:
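A rough usage sketch for the new get_tokens_length helper, with the model and fragments stubbed out; the stub tokenizer just splits on whitespace (not what CLIP does) and only illustrates that per-fragment token counts, excluding start/end markers, are what drive the opcode spans built above:

# Sketch only: stand-ins for the real model / Fragment objects.
class _StubCLIPEmbedder:
    def get_tokens(self, texts, include_start_and_end_markers=False):
        return [t.split() for t in texts]   # fake 'token ids': one per word

class _StubModel:
    cond_stage_model = _StubCLIPEmbedder()

class _StubFragment:
    def __init__(self, text):
        self.text = text

def get_tokens_length(model, fragments):    # same body as in the diff above
    fragment_texts = [x.text for x in fragments]
    tokens = model.cond_stage_model.get_tokens(fragment_texts, include_start_and_end_markers=False)
    return sum([len(x) for x in tokens])

print(get_tokens_length(_StubModel(), [_StubFragment("a cat"), _StubFragment("wearing a hat")]))  # 5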

View File

@@ -557,6 +557,21 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
        else:
            return batch_z
    def get_tokens(self, fragments: list[str], include_start_and_end_markers: bool = True) -> list[list[int]]:
        tokens = self.tokenizer(
            fragments,
            truncation=True,
            max_length=self.max_length,
            return_overflowing_tokens=False,
            padding='do_not_pad',
            return_tensors=None, # just give me a list of ints
        )['input_ids']
        if include_start_and_end_markers:
            return tokens
        else:
            return [x[1:-1] for x in tokens]
    @classmethod
    def apply_embedding_weights(self, embeddings: torch.Tensor, per_embedding_weights: list[float], normalize:bool) -> torch.Tensor:
        per_embedding_weights = torch.tensor(per_embedding_weights, dtype=embeddings.dtype, device=embeddings.device)
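To see what get_tokens returns, a standalone sketch that reproduces the tokenizer call above using the Hugging Face CLIP tokenizer directly; it assumes the openai/clip-vit-large-patch14 tokenizer and max_length=77 that FrozenCLIPEmbedder uses by default, and the exact ids depend on the tokenizer version:

# Sketch only, outside the embedder class.
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
fragments = ["a cat", "wearing a hat"]

token_ids = tokenizer(
    fragments,
    truncation=True,
    max_length=77,                   # assumed FrozenCLIPEmbedder default
    return_overflowing_tokens=False,
    padding='do_not_pad',            # ragged python lists, no padding tokens
    return_tensors=None,             # plain lists of ints, not tensors
)['input_ids']

print(token_ids)                     # each list starts/ends with the BOS/EOS ids
print([x[1:-1] for x in token_ids])  # what include_start_and_end_markers=False returns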