mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
wip implementing options in diffuse step
This commit is contained in:
parent
ee7d4d712a
commit
8273c04575
@ -16,7 +16,8 @@ from typing import Union
|
||||
import torch
|
||||
|
||||
from .prompt_parser import PromptParser, Blend, FlattenedPrompt, \
|
||||
CrossAttentionControlledFragment, CrossAttentionControlSubstitute, CrossAttentionControlAppend
|
||||
CrossAttentionControlledFragment, CrossAttentionControlSubstitute, CrossAttentionControlAppend, Fragment
|
||||
from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder
|
||||
|
||||
|
||||
@ -65,27 +66,54 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
|
||||
if wants_cross_attention_control:
|
||||
original_prompt = FlattenedPrompt()
|
||||
edited_prompt = FlattenedPrompt()
|
||||
# for name, a0, a1, b0, b1 in edit_opcodes: only name == 'equal' is currently parsed
|
||||
original_token_count = 0
|
||||
edited_token_count = 0
|
||||
edit_opcodes = []
|
||||
edit_options = []
|
||||
for fragment in flattened_prompt.children:
|
||||
if type(fragment) is CrossAttentionControlSubstitute:
|
||||
original_prompt.append(fragment.original)
|
||||
edited_prompt.append(fragment.edited)
|
||||
|
||||
to_replace_token_count = get_tokens_length(model, fragment.original)
|
||||
replacement_token_count = get_tokens_length(model, fragment.edited)
|
||||
edit_opcodes.append(('replace',
|
||||
original_token_count, original_token_count + to_replace_token_count,
|
||||
edited_token_count, edited_token_count + replacement_token_count
|
||||
))
|
||||
original_token_count += to_replace_token_count
|
||||
edited_token_count += replacement_token_count
|
||||
edit_options.append(fragment.options)
|
||||
#elif type(fragment) is CrossAttentionControlAppend:
|
||||
# edited_prompt.append(fragment.fragment)
|
||||
else:
|
||||
# regular fragment
|
||||
original_prompt.append(fragment)
|
||||
edited_prompt.append(fragment)
|
||||
|
||||
count = get_tokens_length(model, [fragment])
|
||||
edit_opcodes.append(('equal', original_token_count, original_token_count+count, edited_token_count, edited_token_count+count))
|
||||
edit_options.append(None)
|
||||
original_token_count += count
|
||||
edited_token_count += count
|
||||
original_embeddings, original_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, original_prompt)
|
||||
edited_embeddings, edited_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, edited_prompt)
|
||||
|
||||
conditioning = original_embeddings
|
||||
edited_conditioning = edited_embeddings
|
||||
edit_opcodes = build_token_edit_opcodes(original_tokens, edited_tokens)
|
||||
print('got edit_opcodes', edit_opcodes, 'options', edit_options)
|
||||
else:
|
||||
conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt)
|
||||
|
||||
|
||||
unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, parsed_negative_prompt)
|
||||
return (unconditioning, conditioning, edited_conditioning, edit_opcodes)
|
||||
return (
|
||||
unconditioning, conditioning, edited_conditioning, edit_opcodes
|
||||
#InvokeAIDiffuserComponent.ExtraConditioningInfo(edited_conditioning=edited_conditioning,
|
||||
# edit_opcodes=edit_opcodes,
|
||||
# edit_options=edit_options)
|
||||
)
|
||||
|
||||
|
||||
def build_token_edit_opcodes(original_tokens, edited_tokens):
|
||||
@ -102,6 +130,10 @@ def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: Fl
|
||||
embeddings, tokens = model.get_learned_conditioning([fragments], return_tokens=True, fragment_weights=[weights])
|
||||
return embeddings, tokens
|
||||
|
||||
def get_tokens_length(model, fragments: list[Fragment]):
|
||||
fragment_texts = [x.text for x in fragments]
|
||||
tokens = model.cond_stage_model.get_tokens(fragment_texts, include_start_and_end_markers=False)
|
||||
return sum([len(x) for x in tokens])
|
||||
|
||||
|
||||
def split_weighted_subprompts(text, skip_normalize=False)->list:
|
||||
|
@ -557,6 +557,21 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
|
||||
else:
|
||||
return batch_z
|
||||
|
||||
def get_tokens(self, fragments: list[str], include_start_and_end_markers: bool = True) -> list[list[int]]:
|
||||
tokens = self.tokenizer(
|
||||
fragments,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
return_overflowing_tokens=False,
|
||||
padding='do_not_pad',
|
||||
return_tensors=None, # just give me a list of ints
|
||||
)['input_ids']
|
||||
if include_start_and_end_markers:
|
||||
return tokens
|
||||
else:
|
||||
return [x[1:-1] for x in tokens]
|
||||
|
||||
|
||||
@classmethod
|
||||
def apply_embedding_weights(self, embeddings: torch.Tensor, per_embedding_weights: list[float], normalize:bool) -> torch.Tensor:
|
||||
per_embedding_weights = torch.tensor(per_embedding_weights, dtype=embeddings.dtype, device=embeddings.device)
|
||||
|
Loading…
Reference in New Issue
Block a user