Merge branch 'main' into install/refactor-configure-and-model-select

This commit is contained in:
Lincoln Stein 2023-02-22 14:22:52 -05:00 committed by GitHub
commit 16aea1e869
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 365 additions and 1780 deletions

View File

@ -25,12 +25,13 @@ from invokeai.backend.modules.parameters import parameters_to_command
import invokeai.frontend.dist as frontend
from ldm.generate import Generate
from ldm.invoke.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash
from ldm.invoke.conditioning import get_tokens_for_prompt, get_prompt_structure
from ldm.invoke.conditioning import get_tokens_for_prompt_object, get_prompt_structure, split_weighted_subprompts, \
get_tokenizer
from ldm.invoke.generator.diffusers_pipeline import PipelineIntermediateState
from ldm.invoke.generator.inpaint import infill_methods
from ldm.invoke.globals import Globals, global_converted_ckpts_dir
from ldm.invoke.pngwriter import PngWriter, retrieve_metadata
from ldm.invoke.prompt_parser import split_weighted_subprompts, Blend
from compel.prompt_parser import Blend
from ldm.invoke.globals import global_models_dir
from ldm.invoke.merge_diffusers import merge_diffusion_models
@ -1258,7 +1259,7 @@ class InvokeAIWebServer:
parsed_prompt, _ = get_prompt_structure(
generation_parameters["prompt"])
tokens = None if type(parsed_prompt) is Blend else \
get_tokens_for_prompt(self.generate.model, parsed_prompt)
get_tokens_for_prompt_object(get_tokenizer(self.generate.model), parsed_prompt)
attention_maps_image_base64_url = None if attention_maps_image is None \
else image_to_dataURL(attention_maps_image)
@ -1383,13 +1384,6 @@ class InvokeAIWebServer:
# semantic drift
rfc_dict["sampler"] = parameters["sampler_name"]
# display weighted subprompts (liable to change)
subprompts = split_weighted_subprompts(
parameters["prompt"], skip_normalize=True
)
subprompts = [{"prompt": x[0], "weight": x[1]} for x in subprompts]
rfc_dict["prompt"] = subprompts
# 'variations' should always exist and be an array, empty or consisting of {'seed': seed, 'weight': weight} pairs
variations = []

File diff suppressed because one or more lines are too long

View File

@ -5,7 +5,7 @@
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>InvokeAI - A Stable Diffusion Toolkit</title>
<link rel="shortcut icon" type="icon" href="./assets/favicon-0d253ced.ico" />
<script type="module" crossorigin src="./assets/index-762ec810.js"></script>
<script type="module" crossorigin src="./assets/index-0e39fbc4.js"></script>
<link rel="stylesheet" href="./assets/index-14cb2922.css">
</head>

View File

@ -29,7 +29,8 @@ export declare type PromptItem = {
weight: number;
};
export declare type Prompt = Array<PromptItem>;
// TECHDEBT: We need to retain compatibility with plain prompt strings and the structure Prompt type
export declare type Prompt = Array<PromptItem> | string;
export declare type SeedWeightPair = {
seed: number;

View File

@ -1,9 +1,11 @@
import * as InvokeAI from 'app/invokeai';
import promptToString from './promptToString';
export function getPromptAndNegative(input_prompt: InvokeAI.Prompt) {
let prompt: string = promptToString(input_prompt);
let negativePrompt: string | null = null;
export function getPromptAndNegative(inputPrompt: InvokeAI.Prompt) {
let prompt: string =
typeof inputPrompt === 'string' ? inputPrompt : promptToString(inputPrompt);
let negativePrompt = '';
// Matches all negative prompts, 1st capturing group is the prompt itself
const negativePromptRegExp = new RegExp(/\[([^\][]*)]/, 'gi');

View File

@ -1,6 +1,10 @@
import * as InvokeAI from 'app/invokeai';
const promptToString = (prompt: InvokeAI.Prompt): string => {
if (typeof prompt === 'string') {
return prompt;
}
if (prompt.length === 1) {
return prompt[0].prompt;
}

View File

@ -7,7 +7,6 @@ import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIIconButton from 'common/components/IAIIconButton';
import IAIPopover from 'common/components/IAIPopover';
import { getPromptAndNegative } from 'common/util/getPromptAndNegative';
import {
setDoesCanvasNeedScaling,
setInitialCanvasImage,
@ -20,8 +19,6 @@ import UpscaleSettings from 'features/parameters/components/AdvancedParameters/U
import {
setAllParameters,
setInitialImage,
setNegativePrompt,
setPrompt,
setSeed,
} from 'features/parameters/store/generationSlice';
import { postprocessingSelector } from 'features/parameters/store/postprocessingSelectors';
@ -53,6 +50,8 @@ import {
} from 'react-icons/fa';
import { gallerySelector } from '../store/gallerySelectors';
import DeleteImageModal from './DeleteImageModal';
import { useCallback } from 'react';
import useSetBothPrompts from 'features/parameters/hooks/usePrompt';
const currentImageButtonsSelector = createSelector(
[
@ -125,6 +124,7 @@ const CurrentImageButtons = () => {
const toast = useToast();
const { t } = useTranslation();
const setBothPrompts = useSetBothPrompts();
const handleClickUseAsInitialImage = () => {
if (!currentImage) return;
@ -253,18 +253,11 @@ const CurrentImageButtons = () => {
[currentImage]
);
const handleClickUsePrompt = () => {
const handleClickUsePrompt = useCallback(() => {
if (currentImage?.metadata?.image?.prompt) {
const [prompt, negativePrompt] = getPromptAndNegative(
currentImage?.metadata?.image?.prompt
);
prompt && dispatch(setPrompt(prompt));
negativePrompt
? dispatch(setNegativePrompt(negativePrompt))
: dispatch(setNegativePrompt(''));
setBothPrompts(currentImage?.metadata?.image?.prompt);
}
};
}, [currentImage?.metadata?.image?.prompt, setBothPrompts]);
useHotkeys(
'p',

View File

@ -8,8 +8,6 @@ import {
setAllImageToImageParameters,
setAllParameters,
setInitialImage,
setNegativePrompt,
setPrompt,
setSeed,
} from 'features/parameters/store/generationSlice';
import { DragEvent, memo, useState } from 'react';
@ -18,7 +16,6 @@ import DeleteImageModal from './DeleteImageModal';
import * as ContextMenu from '@radix-ui/react-context-menu';
import * as InvokeAI from 'app/invokeai';
import { getPromptAndNegative } from 'common/util/getPromptAndNegative';
import {
resizeAndScaleCanvas,
setInitialCanvasImage,
@ -26,6 +23,7 @@ import {
import { hoverableImageSelector } from 'features/gallery/store/gallerySelectors';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { useTranslation } from 'react-i18next';
import useSetBothPrompts from 'features/parameters/hooks/usePrompt';
interface HoverableImageProps {
image: InvokeAI.Image;
@ -55,23 +53,16 @@ const HoverableImage = memo((props: HoverableImageProps) => {
const [isHovered, setIsHovered] = useState<boolean>(false);
const toast = useToast();
const { t } = useTranslation();
const setBothPrompts = useSetBothPrompts();
const handleMouseOver = () => setIsHovered(true);
const handleMouseOut = () => setIsHovered(false);
const handleUsePrompt = () => {
if (image.metadata) {
const [prompt, negativePrompt] = getPromptAndNegative(
image.metadata?.image?.prompt
);
prompt && dispatch(setPrompt(prompt));
negativePrompt
? dispatch(setNegativePrompt(negativePrompt))
: dispatch(setNegativePrompt(''));
if (image.metadata?.image?.prompt) {
setBothPrompts(image.metadata?.image?.prompt);
}
toast({

View File

@ -12,6 +12,7 @@ import * as InvokeAI from 'app/invokeai';
import { useAppDispatch } from 'app/storeHooks';
import promptToString from 'common/util/promptToString';
import { seedWeightsToString } from 'common/util/seedWeightPairs';
import useSetBothPrompts from 'features/parameters/hooks/usePrompt';
import {
setCfgScale,
setHeight,
@ -19,7 +20,6 @@ import {
setInitialImage,
setMaskPath,
setPerlin,
setPrompt,
setSampler,
setSeamless,
setSeed,
@ -129,6 +129,8 @@ const ImageMetadataViewer = memo(
({ image, styleClass }: ImageMetadataViewerProps) => {
const dispatch = useAppDispatch();
const setBothPrompts = useSetBothPrompts();
useHotkeys('esc', () => {
dispatch(setShouldShowImageDetails(false));
});
@ -152,7 +154,6 @@ const ImageMetadataViewer = memo(
seed,
steps,
strength,
denoise_str,
threshold,
type,
variations,
@ -189,8 +190,10 @@ const ImageMetadataViewer = memo(
<MetadataItem
label="Prompt"
labelPosition="top"
value={promptToString(prompt)}
onClick={() => dispatch(setPrompt(prompt))}
value={
typeof prompt === 'string' ? prompt : promptToString(prompt)
}
onClick={() => setBothPrompts(prompt)}
/>
)}
{seed !== undefined && (

View File

@ -0,0 +1,26 @@
import { getPromptAndNegative } from 'common/util/getPromptAndNegative';
import * as InvokeAI from 'app/invokeai';
import promptToString from 'common/util/promptToString';
import { useAppDispatch } from 'app/storeHooks';
import { setNegativePrompt, setPrompt } from '../store/generationSlice';
// TECHDEBT: We have two metadata prompt formats and need to handle recalling either of them.
// This hook provides a function to do that.
const useSetBothPrompts = () => {
const dispatch = useAppDispatch();
return (inputPrompt: InvokeAI.Prompt) => {
const promptString =
typeof inputPrompt === 'string'
? inputPrompt
: promptToString(inputPrompt);
const [prompt, negativePrompt] = getPromptAndNegative(promptString);
dispatch(setPrompt(prompt));
dispatch(setNegativePrompt(negativePrompt));
};
};
export default useSetBothPrompts;

File diff suppressed because one or more lines are too long

View File

@ -981,7 +981,7 @@ class Generate:
ti_path, defer_injecting_tokens=True
)
print(
f'>> Textual inversion triggers: {", ".join(self.model.textual_inversion_manager.get_all_trigger_strings())}'
f'>> Textual inversion triggers: {", ".join(sorted(self.model.textual_inversion_manager.get_all_trigger_strings()))}'
)
self.model_name = model_name

View File

@ -9,6 +9,8 @@ from typing import Union
import click
from compel import PromptParser
if sys.platform == "darwin":
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
@ -25,7 +27,6 @@ from .image_util import make_grid
from .log import write_log
from .model_manager import ModelManager
from .pngwriter import PngWriter, retrieve_metadata, write_metadata
from .prompt_parser import PromptParser
from .readline import Completer, get_completer
from ..util import url_attachment_name

View File

@ -97,8 +97,9 @@ from typing import List
import ldm.invoke
import ldm.invoke.pngwriter
from ldm.invoke.conditioning import split_weighted_subprompts
from ldm.invoke.globals import Globals
from ldm.invoke.prompt_parser import split_weighted_subprompts
APP_ID = ldm.invoke.__app_id__
APP_NAME = ldm.invoke.__app_name__

View File

@ -7,61 +7,116 @@ get_uc_and_c_and_ec() get the conditioned and unconditioned latent, an
'''
import re
from typing import Union
from typing import Union, Optional, Any
import torch
from transformers import CLIPTokenizer, CLIPTextModel
from .prompt_parser import PromptParser, Blend, FlattenedPrompt, \
CrossAttentionControlledFragment, CrossAttentionControlSubstitute, Fragment
from ..models.diffusion import cross_attention_control
from compel import Compel
from compel.prompt_parser import FlattenedPrompt, Blend, Fragment, CrossAttentionControlSubstitute, PromptParser
from .devices import torch_dtype
from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder
from ..modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter
from ldm.invoke.globals import Globals
def get_tokenizer(model) -> CLIPTokenizer:
# TODO remove legacy ckpt fallback handling
return (getattr(model, 'tokenizer', None) # diffusers
or model.cond_stage_model.tokenizer) # ldm
def get_text_encoder(model) -> Any:
# TODO remove legacy ckpt fallback handling
return (getattr(model, 'text_encoder', None) # diffusers
or UnsqueezingLDMTransformer(model.cond_stage_model.transformer)) # ldm
class UnsqueezingLDMTransformer:
def __init__(self, ldm_transformer):
self.ldm_transformer = ldm_transformer
@property
def device(self):
return self.ldm_transformer.device
def __call__(self, *args, **kwargs):
insufficiently_unsqueezed_tensor = self.ldm_transformer(*args, **kwargs)
return insufficiently_unsqueezed_tensor.unsqueeze(0)
def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False):
# lazy-load any deferred textual inversions.
# this might take a couple of seconds the first time a textual inversion is used.
model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(prompt_string)
prompt, negative_prompt = get_prompt_structure(prompt_string,
skip_normalize_legacy_blend=skip_normalize_legacy_blend)
conditioning = _get_conditioning_for_prompt(prompt, negative_prompt, model, log_tokens)
tokenizer = get_tokenizer(model)
text_encoder = get_text_encoder(model)
compel = Compel(tokenizer=tokenizer,
text_encoder=text_encoder,
textual_inversion_manager=model.textual_inversion_manager,
dtype_for_device_getter=torch_dtype)
return conditioning
positive_prompt_string, negative_prompt_string = split_prompt_to_positive_and_negative(prompt_string)
legacy_blend = try_parse_legacy_blend(positive_prompt_string, skip_normalize_legacy_blend)
positive_prompt: FlattenedPrompt|Blend
if legacy_blend is not None:
positive_prompt = legacy_blend
else:
positive_prompt = Compel.parse_prompt_string(positive_prompt_string)
negative_prompt: FlattenedPrompt|Blend = Compel.parse_prompt_string(negative_prompt_string)
if log_tokens or getattr(Globals, "log_tokenization", False):
log_tokenization(positive_prompt, negative_prompt, tokenizer=tokenizer)
c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)
tokens_count = get_max_token_count(tokenizer, positive_prompt)
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(tokens_count_including_eos_bos=tokens_count,
cross_attention_control_args=options.get(
'cross_attention_control', None))
return uc, c, ec
def get_prompt_structure(prompt_string, skip_normalize_legacy_blend: bool = False) -> (
Union[FlattenedPrompt, Blend], FlattenedPrompt):
"""
parse the passed-in prompt string and return tuple (positive_prompt, negative_prompt)
"""
prompt, negative_prompt = _parse_prompt_string(prompt_string,
skip_normalize_legacy_blend=skip_normalize_legacy_blend)
return prompt, negative_prompt
Union[FlattenedPrompt, Blend], FlattenedPrompt):
positive_prompt_string, negative_prompt_string = split_prompt_to_positive_and_negative(prompt_string)
legacy_blend = try_parse_legacy_blend(positive_prompt_string, skip_normalize_legacy_blend)
positive_prompt: FlattenedPrompt|Blend
if legacy_blend is not None:
positive_prompt = legacy_blend
else:
positive_prompt = Compel.parse_prompt_string(positive_prompt_string)
negative_prompt: FlattenedPrompt|Blend = Compel.parse_prompt_string(negative_prompt_string)
return positive_prompt, negative_prompt
def get_max_token_count(tokenizer, prompt: FlattenedPrompt|Blend, truncate_if_too_long=True) -> int:
if type(prompt) is Blend:
blend: Blend = prompt
return max([get_max_token_count(tokenizer, c, truncate_if_too_long) for c in blend.prompts])
else:
return len(get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long))
def get_tokens_for_prompt(model, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True) -> [str]:
def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True) -> [str]:
if type(parsed_prompt) is Blend:
raise ValueError("Blend is not supported here - you need to get tokens for each of its .children")
text_fragments = [x.text if type(x) is Fragment else
(" ".join([f.text for f in x.original]) if type(x) is CrossAttentionControlSubstitute else
str(x))
for x in parsed_prompt.children]
text = " ".join(text_fragments)
tokens = model.cond_stage_model.tokenizer.tokenize(text)
tokens = tokenizer.tokenize(text)
if truncate_if_too_long:
max_tokens_length = model.cond_stage_model.max_length - 2 # typically 75
max_tokens_length = tokenizer.model_max_length - 2 # typically 75
tokens = tokens[0:max_tokens_length]
return tokens
def _parse_prompt_string(prompt_string_uncleaned, skip_normalize_legacy_blend=False) -> Union[FlattenedPrompt, Blend]:
# Extract Unconditioned Words From Prompt
def split_prompt_to_positive_and_negative(prompt_string_uncleaned):
unconditioned_words = ''
unconditional_regex = r'\[(.*?)\]'
unconditionals = re.findall(unconditional_regex, prompt_string_uncleaned)
if len(unconditionals) > 0:
unconditioned_words = ' '.join(unconditionals)
@ -71,210 +126,57 @@ def _parse_prompt_string(prompt_string_uncleaned, skip_normalize_legacy_blend=Fa
prompt_string_cleaned = re.sub(' +', ' ', clean_prompt)
else:
prompt_string_cleaned = prompt_string_uncleaned
pp = PromptParser()
parsed_prompt: Union[FlattenedPrompt, Blend] = None
legacy_blend: Blend = pp.parse_legacy_blend(prompt_string_cleaned, skip_normalize_legacy_blend)
if legacy_blend is not None:
parsed_prompt = legacy_blend
else:
# we don't support conjunctions for now
parsed_prompt = pp.parse_conjunction(prompt_string_cleaned).prompts[0]
parsed_negative_prompt: FlattenedPrompt = pp.parse_conjunction(unconditioned_words).prompts[0]
return parsed_prompt, parsed_negative_prompt
return prompt_string_cleaned, unconditioned_words
def _get_conditioning_for_prompt(parsed_prompt: Union[Blend, FlattenedPrompt], parsed_negative_prompt: FlattenedPrompt,
model, log_tokens=False) \
-> tuple[torch.Tensor, torch.Tensor, InvokeAIDiffuserComponent.ExtraConditioningInfo]:
"""
Process prompt structure and tokens, and return (conditioning, unconditioning, extra_conditioning_info)
"""
def log_tokenization(positive_prompt: Blend | FlattenedPrompt,
negative_prompt: Blend | FlattenedPrompt,
tokenizer):
print(f"\n>> [TOKENLOG] Parsed Prompt: {positive_prompt}")
print(f"\n>> [TOKENLOG] Parsed Negative Prompt: {negative_prompt}")
if log_tokens or getattr(Globals, "log_tokenization", False):
print(f"\n>> [TOKENLOG] Parsed Prompt: {parsed_prompt}")
print(f"\n>> [TOKENLOG] Parsed Negative Prompt: {parsed_negative_prompt}")
log_tokenization_for_prompt_object(positive_prompt, tokenizer)
log_tokenization_for_prompt_object(negative_prompt, tokenizer, display_label_prefix="(negative prompt)")
conditioning = None
cac_args: cross_attention_control.Arguments = None
if type(parsed_prompt) is Blend:
conditioning = _get_conditioning_for_blend(model, parsed_prompt, log_tokens)
elif type(parsed_prompt) is FlattenedPrompt:
if parsed_prompt.wants_cross_attention_control:
conditioning, cac_args = _get_conditioning_for_cross_attention_control(model, parsed_prompt, log_tokens)
def log_tokenization_for_prompt_object(p: Blend | FlattenedPrompt, tokenizer, display_label_prefix=None):
display_label_prefix = display_label_prefix or ""
if type(p) is Blend:
blend: Blend = p
for i, c in enumerate(blend.prompts):
log_tokenization_for_prompt_object(
c, tokenizer,
display_label_prefix=f"{display_label_prefix}(blend part {i + 1}, weight={blend.weights[i]})")
elif type(p) is FlattenedPrompt:
flattened_prompt: FlattenedPrompt = p
if flattened_prompt.wants_cross_attention_control:
original_fragments = []
edited_fragments = []
for f in flattened_prompt.children:
if type(f) is CrossAttentionControlSubstitute:
original_fragments += f.original
edited_fragments += f.edited
else:
original_fragments.append(f)
edited_fragments.append(f)
original_text = " ".join([x.text for x in original_fragments])
log_tokenization_for_text(original_text, tokenizer,
display_label=f"{display_label_prefix}(.swap originals)")
edited_text = " ".join([x.text for x in edited_fragments])
log_tokenization_for_text(edited_text, tokenizer,
display_label=f"{display_label_prefix}(.swap replacements)")
else:
conditioning, _ = _get_embeddings_and_tokens_for_prompt(model,
parsed_prompt,
log_tokens=log_tokens,
log_display_label="(prompt)")
else:
raise ValueError(f"parsed_prompt is '{type(parsed_prompt)}' which is not a supported prompt type")
unconditioning, _ = _get_embeddings_and_tokens_for_prompt(model,
parsed_negative_prompt,
log_tokens=log_tokens,
log_display_label="(unconditioning)")
if isinstance(conditioning, dict):
# hybrid conditioning is in play
unconditioning, conditioning = _flatten_hybrid_conditioning(unconditioning, conditioning)
if cac_args is not None:
print(
">> Hybrid conditioning cannot currently be combined with cross attention control. Cross attention control will be ignored.")
cac_args = None
if type(parsed_prompt) is Blend:
blend: Blend = parsed_prompt
all_token_sequences = [get_tokens_for_prompt(model, p) for p in blend.prompts]
longest_token_sequence = max(all_token_sequences, key=lambda t: len(t))
eos_token_index = len(longest_token_sequence)+1
else:
tokens = get_tokens_for_prompt(model, parsed_prompt)
eos_token_index = len(tokens)+1
return (
unconditioning, conditioning, InvokeAIDiffuserComponent.ExtraConditioningInfo(
tokens_count_including_eos_bos=eos_token_index + 1,
cross_attention_control_args=cac_args
)
)
text = " ".join([x.text for x in flattened_prompt.children])
log_tokenization_for_text(text, tokenizer, display_label=display_label_prefix)
def _get_conditioning_for_cross_attention_control(model, prompt: FlattenedPrompt, log_tokens: bool = True):
original_prompt = FlattenedPrompt()
edited_prompt = FlattenedPrompt()
# for name, a0, a1, b0, b1 in edit_opcodes: only name == 'equal' is currently parsed
original_token_count = 0
edited_token_count = 0
edit_options = []
edit_opcodes = []
# beginning of sequence
edit_opcodes.append(
('equal', original_token_count, original_token_count + 1, edited_token_count, edited_token_count + 1))
edit_options.append(None)
original_token_count += 1
edited_token_count += 1
for fragment in prompt.children:
if type(fragment) is CrossAttentionControlSubstitute:
original_prompt.append(fragment.original)
edited_prompt.append(fragment.edited)
to_replace_token_count = _get_tokens_length(model, fragment.original)
replacement_token_count = _get_tokens_length(model, fragment.edited)
edit_opcodes.append(('replace',
original_token_count, original_token_count + to_replace_token_count,
edited_token_count, edited_token_count + replacement_token_count
))
original_token_count += to_replace_token_count
edited_token_count += replacement_token_count
edit_options.append(fragment.options)
# elif type(fragment) is CrossAttentionControlAppend:
# edited_prompt.append(fragment.fragment)
else:
# regular fragment
original_prompt.append(fragment)
edited_prompt.append(fragment)
count = _get_tokens_length(model, [fragment])
edit_opcodes.append(('equal', original_token_count, original_token_count + count, edited_token_count,
edited_token_count + count))
edit_options.append(None)
original_token_count += count
edited_token_count += count
# end of sequence
edit_opcodes.append(
('equal', original_token_count, original_token_count + 1, edited_token_count, edited_token_count + 1))
edit_options.append(None)
original_token_count += 1
edited_token_count += 1
original_embeddings, original_tokens = _get_embeddings_and_tokens_for_prompt(model,
original_prompt,
log_tokens=log_tokens,
log_display_label="(.swap originals)")
# naïvely building a single edited_embeddings like this disregards the effects of changing the absolute location of
# subsequent tokens when there is >1 edit and earlier edits change the total token count.
# eg "a cat.swap(smiling dog, s_start=0.5) eating a hotdog.swap(pizza)" - when the 'pizza' edit is active but the
# 'cat' edit is not, the 'pizza' feature vector will nevertheless be affected by the introduction of the extra
# token 'smiling' in the inactive 'cat' edit.
# todo: build multiple edited_embeddings, one for each edit, and pass just the edited fragments through to the CrossAttentionControl functions
edited_embeddings, edited_tokens = _get_embeddings_and_tokens_for_prompt(model,
edited_prompt,
log_tokens=log_tokens,
log_display_label="(.swap replacements)")
conditioning = original_embeddings
edited_conditioning = edited_embeddings
# print('>> got edit_opcodes', edit_opcodes, 'options', edit_options)
cac_args = cross_attention_control.Arguments(
edited_conditioning=edited_conditioning,
edit_opcodes=edit_opcodes,
edit_options=edit_options
)
return conditioning, cac_args
def _get_conditioning_for_blend(model, blend: Blend, log_tokens: bool = False):
embeddings_to_blend = None
for i, flattened_prompt in enumerate(blend.prompts):
this_embedding, _ = _get_embeddings_and_tokens_for_prompt(model,
flattened_prompt,
log_tokens=log_tokens,
log_display_label=f"(blend part {i + 1}, weight={blend.weights[i]})")
embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat(
(embeddings_to_blend, this_embedding))
conditioning = WeightedPromptFragmentsToEmbeddingsConverter.apply_embedding_weights(embeddings_to_blend.unsqueeze(0),
blend.weights,
normalize=blend.normalize_weights)
return conditioning
def _get_embeddings_and_tokens_for_prompt(model, flattened_prompt: FlattenedPrompt, log_tokens: bool = False,
log_display_label: str = None):
if type(flattened_prompt) is not FlattenedPrompt:
raise Exception(f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead")
fragments = [x.text for x in flattened_prompt.children]
weights = [x.weight for x in flattened_prompt.children]
embeddings, tokens = model.get_learned_conditioning([fragments], return_tokens=True, fragment_weights=[weights])
if log_tokens or getattr(Globals, "log_tokenization", False):
text = " ".join(fragments)
log_tokenization(text, model, display_label=log_display_label)
return embeddings, tokens
def _get_tokens_length(model, fragments: list[Fragment]):
fragment_texts = [x.text for x in fragments]
tokens = model.cond_stage_model.get_token_ids(fragment_texts, include_start_and_end_markers=False)
return sum([len(x) for x in tokens])
def _flatten_hybrid_conditioning(uncond, cond):
'''
This handles the choice between a conditional conditioning
that is a tensor (used by cross attention) vs one that has additional
dimensions as well, as used by 'hybrid'
'''
assert isinstance(uncond, dict)
assert isinstance(cond, dict)
cond_flattened = dict()
for k in cond:
if isinstance(cond[k], list):
cond_flattened[k] = [
torch.cat([uncond[k][i], cond[k][i]])
for i in range(len(cond[k]))
]
else:
cond_flattened[k] = torch.cat([uncond[k], cond[k]])
return uncond, cond_flattened
def log_tokenization(text, model, display_label=None):
def log_tokenization_for_text(text, tokenizer, display_label=None):
""" shows how the prompt is tokenized
# usually tokens have '</w>' to indicate end-of-word,
# but for readability it has been replaced with ' '
"""
tokens = model.cond_stage_model.tokenizer.tokenize(text)
tokens = tokenizer.tokenize(text)
tokenized = ""
discarded = ""
usedTokens = 0
@ -284,7 +186,7 @@ def log_tokenization(text, model, display_label=None):
token = tokens[i].replace('</w>', ' ')
# alternate color
s = (usedTokens % 6) + 1
if i < model.cond_stage_model.max_length:
if i < tokenizer.model_max_length:
tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
usedTokens += 1
else: # over max token length
@ -293,7 +195,58 @@ def log_tokenization(text, model, display_label=None):
if usedTokens > 0:
print(f'\n>> [TOKENLOG] Tokens {display_label or ""} ({usedTokens}):')
print(f'{tokenized}\x1b[0m')
if discarded != "":
print(f'\n>> [TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):')
print(f'{discarded}\x1b[0m')
def try_parse_legacy_blend(text: str, skip_normalize: bool=False) -> Optional[Blend]:
weighted_subprompts = split_weighted_subprompts(text, skip_normalize=skip_normalize)
if len(weighted_subprompts) <= 1:
return None
strings = [x[0] for x in weighted_subprompts]
weights = [x[1] for x in weighted_subprompts]
pp = PromptParser()
parsed_conjunctions = [pp.parse_conjunction(x) for x in strings]
flattened_prompts = [x.prompts[0] for x in parsed_conjunctions]
return Blend(prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize)
def split_weighted_subprompts(text, skip_normalize=False)->list:
"""
Legacy blend parsing.
grabs all text up to the first occurrence of ':'
uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
if ':' has no value defined, defaults to 1.0
repeats until no text remaining
"""
prompt_parser = re.compile("""
(?P<prompt> # capture group for 'prompt'
(?:\\\:|[^:])+ # match one or more non ':' characters or escaped colons '\:'
) # end 'prompt'
(?: # non-capture group
:+ # match one or more ':' characters
(?P<weight> # capture group for 'weight'
-?\d+(?:\.\d+)? # match positive or negative integer or decimal number
)? # end weight capture group, make optional
\s* # strip spaces after weight
| # OR
$ # else, if no ':' then match end of line
) # end non-capture group
""", re.VERBOSE)
parsed_prompts = [(match.group("prompt").replace("\\:", ":"), float(
match.group("weight") or 1)) for match in re.finditer(prompt_parser, text)]
if skip_normalize:
return parsed_prompts
weight_sum = sum(map(lambda x: x[1], parsed_prompts))
if weight_sum == 0:
print(
"* Warning: Subprompt weights add up to zero. Discarding and using even weights instead.")
equal_weight = 1 / max(len(parsed_prompts), 1)
return [(x[0], equal_weight) for x in parsed_prompts]
return [(x[0], x[1] / weight_sum) for x in parsed_prompts]

View File

@ -32,7 +32,7 @@ from ldm.modules.textual_inversion_manager import TextualInversionManager
from ..devices import normalize_device, CPU_DEVICE
from ..offloading import LazilyLoadedModelGroup, FullyLoadedModelGroup, ModelGroup
from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver
from ...modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter
from compel import EmbeddingsProvider
@dataclass
@ -295,7 +295,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
text_encoder=self.text_encoder,
full_precision=use_full_precision)
# InvokeAI's interface for text embeddings and whatnot
self.prompt_fragments_to_embeddings_converter = WeightedPromptFragmentsToEmbeddingsConverter(
self.embeddings_provider = EmbeddingsProvider(
tokenizer=self.tokenizer,
text_encoder=self.text_encoder,
textual_inversion_manager=self.textual_inversion_manager
@ -727,15 +727,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
"""
Compatibility function for ldm.models.diffusion.ddpm.LatentDiffusion.
"""
return self.prompt_fragments_to_embeddings_converter.get_embeddings_for_weighted_prompt_fragments(
text=c,
fragment_weights=fragment_weights,
return self.embeddings_provider.get_embeddings_for_weighted_prompt_fragments(
text_batch=c,
fragment_weights_batch=fragment_weights,
should_return_tokens=return_tokens,
device=self._model_group.device_for(self.unet))
@property
def cond_stage_model(self):
return self.prompt_fragments_to_embeddings_converter
return self.embeddings_provider
@torch.inference_mode()
def _tokenize(self, prompt: Union[str, List[str]]):

View File

@ -1,655 +0,0 @@
import string
from typing import Union, Optional
import re
import pyparsing as pp
'''
This module parses prompt strings and produces tree-like structures that can be used generate and control the conditioning tensors.
weighted subprompts.
Useful class exports:
PromptParser - parses prompts
Useful function exports:
split_weighted_subpromopts() split subprompts, normalize and weight them
log_tokenization() print out colour-coded tokens and warn if truncated
'''
class Prompt():
"""
Mid-level structure for storing the tree-like result of parsing a prompt. A Prompt may not represent the whole of
the singular user-defined "prompt string" (although it can) - for example, if the user specifies a Blend, the objects
that are to be blended together are stored individuall as Prompt objects.
Nesting makes this object not suitable for directly tokenizing; instead call flatten() on the containing Conjunction
to produce a FlattenedPrompt.
"""
def __init__(self, parts: list):
for c in parts:
if type(c) is not Attention and not issubclass(type(c), BaseFragment) and type(c) is not pp.ParseResults:
raise PromptParser.ParsingException(f"Prompt cannot contain {type(c).__name__} ({c}), only {[c.__name__ for c in BaseFragment.__subclasses__()]} are allowed")
self.children = parts
def __repr__(self):
return f"Prompt:{self.children}"
def __eq__(self, other):
return type(other) is Prompt and other.children == self.children
class BaseFragment:
pass
class FlattenedPrompt():
"""
A Prompt that has been passed through flatten(). Its children can be readily tokenized.
"""
def __init__(self, parts: list=[]):
self.children = []
for part in parts:
self.append(part)
def append(self, fragment: Union[list, BaseFragment, tuple]):
# verify type correctness
if type(fragment) is list:
for x in fragment:
self.append(x)
elif issubclass(type(fragment), BaseFragment):
self.children.append(fragment)
elif type(fragment) is tuple:
# upgrade tuples to Fragments
if type(fragment[0]) is not str or (type(fragment[1]) is not float and type(fragment[1]) is not int):
raise PromptParser.ParsingException(
f"FlattenedPrompt cannot contain {fragment}, only Fragments or (str, float) tuples are allowed")
self.children.append(Fragment(fragment[0], fragment[1]))
else:
raise PromptParser.ParsingException(
f"FlattenedPrompt cannot contain {fragment}, only Fragments or (str, float) tuples are allowed")
@property
def is_empty(self):
return len(self.children) == 0 or \
(len(self.children) == 1 and len(self.children[0].text) == 0)
@property
def wants_cross_attention_control(self):
return any(
[issubclass(type(x), CrossAttentionControlledFragment) for x in self.children]
)
def __repr__(self):
return f"FlattenedPrompt:{self.children}"
def __eq__(self, other):
return type(other) is FlattenedPrompt and other.children == self.children
class Fragment(BaseFragment):
"""
A Fragment is a chunk of plain text and an optional weight. The text should be passed as-is to the CLIP tokenizer.
"""
def __init__(self, text: str, weight: float=1):
assert(type(text) is str)
if '\\"' in text or '\\(' in text or '\\)' in text:
#print("Fragment converting escaped \( \) \\\" into ( ) \"")
text = text.replace('\\(', '(').replace('\\)', ')').replace('\\"', '"')
self.text = text
self.weight = float(weight)
def __repr__(self):
return "Fragment:'"+self.text+"'@"+str(self.weight)
def __eq__(self, other):
return type(other) is Fragment \
and other.text == self.text \
and other.weight == self.weight
class Attention():
"""
Nestable weight control for fragments. Each object in the children array may in turn be an Attention object;
weights should be considered to accumulate as the tree is traversed to deeper levels of nesting.
Do not traverse directly; instead obtain a FlattenedPrompt by calling Flatten() on a top-level Conjunction object.
"""
def __init__(self, weight: float, children: list):
if type(weight) is not float:
raise PromptParser.ParsingException(
f"Attention weight must be float (got {type(weight).__name__} {weight})")
self.weight = weight
if type(children) is not list:
raise PromptParser.ParsingException(f"cannot make Attention with non-list of children (got {type(children)})")
assert(type(children) is list)
self.children = children
#print(f"A: requested attention '{children}' to {weight}")
def __repr__(self):
return f"Attention:{self.children} * {self.weight}"
def __eq__(self, other):
return type(other) is Attention and other.weight == self.weight and other.fragment == self.fragment
class CrossAttentionControlledFragment(BaseFragment):
pass
class CrossAttentionControlSubstitute(CrossAttentionControlledFragment):
"""
A Cross-Attention Controlled ('prompt2prompt') fragment, for use inside a Prompt, Attention, or FlattenedPrompt.
Representing an "original" word sequence that supplies feature vectors for an initial diffusion operation, and an
"edited" word sequence, to which the attention maps produced by the "original" word sequence are applied. Intuitively,
the result should be an "edited" image that looks like the "original" image with concepts swapped.
eg "a cat sitting on a car" (original) -> "a smiling dog sitting on a car" (edited): the edited image should look
almost exactly the same as the original, but with a smiling dog rendered in place of the cat. The
CrossAttentionControlSubstitute object representing this swap may be confined to the tokens being swapped:
CrossAttentionControlSubstitute(original=[Fragment('cat')], edited=[Fragment('dog')])
or it may represent a larger portion of the token sequence:
CrossAttentionControlSubstitute(original=[Fragment('a cat sitting on a car')],
edited=[Fragment('a smiling dog sitting on a car')])
In either case expect it to be embedded in a Prompt or FlattenedPrompt:
FlattenedPrompt([
Fragment('a'),
CrossAttentionControlSubstitute(original=[Fragment('cat')], edited=[Fragment('dog')]),
Fragment('sitting on a car')
])
"""
def __init__(self, original: list, edited: list, options: dict=None):
self.original = original if len(original)>0 else [Fragment('')]
self.edited = edited if len(edited)>0 else [Fragment('')]
default_options = {
's_start': 0.0,
's_end': 0.2062994740159002, # ~= shape_freedom=0.5
't_start': 0.1,
't_end': 1.0
}
merged_options = default_options
if options is not None:
shape_freedom = options.pop('shape_freedom', None)
if shape_freedom is not None:
# high shape freedom = SD can do what it wants with the shape of the object
# high shape freedom => s_end = 0
# low shape freedom => s_end = 1
# shape freedom is in a "linear" space, while noticeable changes to s_end are typically closer around 0,
# and there is very little perceptible difference as s_end increases above 0.5
# so for shape_freedom = 0.5 we probably want s_end to be 0.2
# -> cube root and subtract from 1.0
merged_options['s_end'] = 1.0 - shape_freedom ** (1. / 3.)
#print('converted shape_freedom argument to', merged_options)
merged_options.update(options)
self.options = merged_options
def __repr__(self):
return f"CrossAttentionControlSubstitute:({self.original}->{self.edited} ({self.options})"
def __eq__(self, other):
return type(other) is CrossAttentionControlSubstitute \
and other.original == self.original \
and other.edited == self.edited \
and other.options == self.options
class CrossAttentionControlAppend(CrossAttentionControlledFragment):
def __init__(self, fragment: Fragment):
self.fragment = fragment
def __repr__(self):
return "CrossAttentionControlAppend:",self.fragment
def __eq__(self, other):
return type(other) is CrossAttentionControlAppend \
and other.fragment == self.fragment
class Conjunction():
"""
Storage for one or more Prompts or Blends, each of which is to be separately diffused and then the results merged
by weighted sum in latent space.
"""
def __init__(self, prompts: list, weights: list = None):
# force everything to be a Prompt
#print("making conjunction with", prompts, "types", [type(p).__name__ for p in prompts])
self.prompts = [x if (type(x) is Prompt
or type(x) is Blend
or type(x) is FlattenedPrompt)
else Prompt(x) for x in prompts]
self.weights = [1.0]*len(self.prompts) if (weights is None or len(weights)==0) else list(weights)
if len(self.weights) != len(self.prompts):
raise PromptParser.ParsingException(f"while parsing Conjunction: mismatched parts/weights counts {prompts}, {weights}")
self.type = 'AND'
def __repr__(self):
return f"Conjunction:{self.prompts} | weights {self.weights}"
def __eq__(self, other):
return type(other) is Conjunction \
and other.prompts == self.prompts \
and other.weights == self.weights
class Blend():
"""
Stores a Blend of multiple Prompts. To apply, build feature vectors for each of the child Prompts and then perform a
weighted blend of the feature vectors to produce a single feature vector that is effectively a lerp between the
Prompts.
"""
def __init__(self, prompts: list, weights: list[float], normalize_weights: bool=True):
#print("making Blend with prompts", prompts, "and weights", weights)
weights = [1.0]*len(prompts) if (weights is None or len(weights)==0) else list(weights)
if len(prompts) != len(weights):
raise PromptParser.ParsingException(f"while parsing Blend: mismatched prompts/weights counts {prompts}, {weights}")
for p in prompts:
if type(p) is not Prompt and type(p) is not FlattenedPrompt:
raise(PromptParser.ParsingException(f"{type(p)} cannot be added to a Blend, only Prompts or FlattenedPrompts"))
for f in p.children:
if isinstance(f, CrossAttentionControlSubstitute):
raise(PromptParser.ParsingException(f"while parsing Blend: sorry, you cannot do .swap() as part of a Blend"))
# upcast all lists to Prompt objects
self.prompts = [x if (type(x) is Prompt or type(x) is FlattenedPrompt)
else Prompt(x)
for x in prompts]
self.prompts = prompts
self.weights = weights
self.normalize_weights = normalize_weights
@property
def wants_cross_attention_control(self):
# blends cannot cross-attention control
return False
def __repr__(self):
return f"Blend:{self.prompts} | weights {' ' if self.normalize_weights else '(non-normalized) '}{self.weights}"
def __eq__(self, other):
return other.__repr__() == self.__repr__()
class PromptParser():
class ParsingException(Exception):
pass
class UnrecognizedOperatorException(ParsingException):
def __init__(self, operator:str):
super().__init__("Unrecognized operator: " + operator)
def __init__(self, attention_plus_base=1.1, attention_minus_base=0.9):
self.conjunction, self.prompt = build_parser_syntax(attention_plus_base, attention_minus_base)
def parse_conjunction(self, prompt: str) -> Conjunction:
'''
:param prompt: The prompt string to parse
:return: a Conjunction representing the parsed results.
'''
#print(f"!!parsing '{prompt}'")
if len(prompt.strip()) == 0:
return Conjunction(prompts=[FlattenedPrompt([('', 1.0)])], weights=[1.0])
root = self.conjunction.parse_string(prompt)
#print(f"'{prompt}' parsed to root", root)
#fused = fuse_fragments(parts)
#print("fused to", fused)
return self.flatten(root[0])
def parse_legacy_blend(self, text: str, skip_normalize: bool = False) -> Optional[Blend]:
weighted_subprompts = split_weighted_subprompts(text, skip_normalize=skip_normalize)
if len(weighted_subprompts) <= 1:
return None
strings = [x[0] for x in weighted_subprompts]
weights = [x[1] for x in weighted_subprompts]
parsed_conjunctions = [self.parse_conjunction(x) for x in strings]
flattened_prompts = [x.prompts[0] for x in parsed_conjunctions]
return Blend(prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize)
def flatten(self, root: Conjunction, verbose = False) -> Conjunction:
"""
Flattening a Conjunction traverses all of the nested tree-like structures in each of its Prompts or Blends,
producing from each of these walks a linear sequence of Fragment or CrossAttentionControlSubstitute objects
that can be readily tokenized without the need to walk a complex tree structure.
:param root: The Conjunction to flatten.
:return: A Conjunction containing the result of flattening each of the prompts in the passed-in root.
"""
def fuse_fragments(items):
# print("fusing fragments in ", items)
result = []
for x in items:
if type(x) is CrossAttentionControlSubstitute:
original_fused = fuse_fragments(x.original)
edited_fused = fuse_fragments(x.edited)
result.append(CrossAttentionControlSubstitute(original_fused, edited_fused, options=x.options))
else:
last_weight = result[-1].weight \
if (len(result) > 0 and not issubclass(type(result[-1]), CrossAttentionControlledFragment)) \
else None
this_text = x.text
this_weight = x.weight
if last_weight is not None and last_weight == this_weight:
last_text = result[-1].text
result[-1] = Fragment(last_text + ' ' + this_text, last_weight)
else:
result.append(x)
return result
def flatten_internal(node, weight_scale, results, prefix):
verbose and print(prefix + "flattening", node, "...")
if type(node) is pp.ParseResults or type(node) is list:
for x in node:
results = flatten_internal(x, weight_scale, results, prefix+' pr ')
#print(prefix, " ParseResults expanded, results is now", results)
elif type(node) is Attention:
# if node.weight < 1:
# todo: inject a blend when flattening attention with weight <1"
for index,c in enumerate(node.children):
results = flatten_internal(c, weight_scale * node.weight, results, prefix + f" att{index} ")
elif type(node) is Fragment:
results += [Fragment(node.text, node.weight*weight_scale)]
elif type(node) is CrossAttentionControlSubstitute:
original = flatten_internal(node.original, weight_scale, [], prefix + ' CAo ')
edited = flatten_internal(node.edited, weight_scale, [], prefix + ' CAe ')
results += [CrossAttentionControlSubstitute(original, edited, options=node.options)]
elif type(node) is Blend:
flattened_subprompts = []
#print(" flattening blend with prompts", node.prompts, "weights", node.weights)
for prompt in node.prompts:
# prompt is a list
flattened_subprompts = flatten_internal(prompt, weight_scale, flattened_subprompts, prefix+'B ')
results += [Blend(prompts=flattened_subprompts, weights=node.weights, normalize_weights=node.normalize_weights)]
elif type(node) is Prompt:
#print(prefix + "about to flatten Prompt with children", node.children)
flattened_prompt = []
for child in node.children:
flattened_prompt = flatten_internal(child, weight_scale, flattened_prompt, prefix+'P ')
results += [FlattenedPrompt(parts=fuse_fragments(flattened_prompt))]
#print(prefix + "after flattening Prompt, results is", results)
else:
raise PromptParser.ParsingException(f"unhandled node type {type(node)} when flattening {node}")
verbose and print(prefix + "-> after flattening", type(node).__name__, "results is", results)
return results
verbose and print("flattening", root)
flattened_parts = []
for part in root.prompts:
flattened_parts += flatten_internal(part, 1.0, [], ' C| ')
verbose and print("flattened to", flattened_parts)
weights = root.weights
return Conjunction(flattened_parts, weights)
def build_parser_syntax(attention_plus_base: float, attention_minus_base: float):
def make_operator_object(x):
#print('making operator for', x)
target = x[0]
operator = x[1]
arguments = x[2]
if operator == '.attend':
weight_raw = arguments[0]
weight = 1.0
if type(weight_raw) is float or type(weight_raw) is int:
weight = weight_raw
elif type(weight_raw) is str:
base = attention_plus_base if weight_raw[0] == '+' else attention_minus_base
weight = pow(base, len(weight_raw))
return Attention(weight=weight, children=[x for x in x[0]])
elif operator == '.swap':
return CrossAttentionControlSubstitute(target, arguments, x.as_dict())
elif operator == '.blend':
prompts = [Prompt(p) for p in x[0]]
weights_raw = x[2]
normalize_weights = True
if len(weights_raw) > 0 and weights_raw[-1][0] == 'no_normalize':
normalize_weights = False
weights_raw = weights_raw[:-1]
weights = [float(w[0]) for w in weights_raw]
return Blend(prompts=prompts, weights=weights, normalize_weights=normalize_weights)
elif operator == '.and' or operator == '.add':
prompts = [Prompt(p) for p in x[0]]
weights = [float(w[0]) for w in x[2]]
return Conjunction(prompts=prompts, weights=weights)
raise PromptParser.UnrecognizedOperatorException(operator)
def parse_fragment_str(x, expression: pp.ParseExpression, in_quotes: bool = False, in_parens: bool = False):
#print(f"parsing fragment string for {x}")
fragment_string = x[0]
if len(fragment_string.strip()) == 0:
return Fragment('')
if in_quotes:
# escape unescaped quotes
fragment_string = fragment_string.replace('"', '\\"')
try:
result = (expression + pp.StringEnd()).parse_string(fragment_string)
#print("parsed to", result)
return result
except pp.ParseException as e:
#print("parse_fragment_str couldn't parse prompt string:", e)
raise
# meaningful symbols
lparen = pp.Literal("(").suppress()
rparen = pp.Literal(")").suppress()
quote = pp.Literal('"').suppress()
comma = pp.Literal(",").suppress()
dot = pp.Literal(".").suppress()
equals = pp.Literal("=").suppress()
escaped_lparen = pp.Literal('\\(')
escaped_rparen = pp.Literal('\\)')
escaped_quote = pp.Literal('\\"')
escaped_comma = pp.Literal('\\,')
escaped_dot = pp.Literal('\\.')
escaped_plus = pp.Literal('\\+')
escaped_minus = pp.Literal('\\-')
escaped_equals = pp.Literal('\\=')
syntactic_symbols = {
'(': escaped_lparen,
')': escaped_rparen,
'"': escaped_quote,
',': escaped_comma,
'.': escaped_dot,
'+': escaped_plus,
'-': escaped_minus,
'=': escaped_equals,
}
syntactic_chars = "".join(syntactic_symbols.keys())
# accepts int or float notation, always maps to float
number = pp.pyparsing_common.real | \
pp.Combine(pp.Optional("-")+pp.Word(pp.nums)).set_parse_action(pp.token_map(float))
# for options
keyword = pp.Word(pp.alphanums + '_')
# a word that absolutely does not contain any meaningful syntax
non_syntax_word = pp.Combine(pp.OneOrMore(pp.MatchFirst([
pp.Or(syntactic_symbols.values()),
pp.one_of(['-', '+']) + pp.NotAny(pp.White() | pp.Char(syntactic_chars) | pp.StringEnd()),
# build character-by-character
pp.CharsNotIn(string.whitespace + syntactic_chars, exact=1)
])))
non_syntax_word.set_parse_action(lambda x: [Fragment(t) for t in x])
non_syntax_word.set_name('non_syntax_word')
non_syntax_word.set_debug(False)
# a word that can contain any character at all - greedily consumes syntax, so use with care
free_word = pp.CharsNotIn(string.whitespace).set_parse_action(lambda x: Fragment(x[0]))
free_word.set_name('free_word')
free_word.set_debug(False)
# ok here we go. forward declare some things..
attention = pp.Forward()
cross_attention_substitute = pp.Forward()
parenthesized_fragment = pp.Forward()
quoted_fragment = pp.Forward()
# the types of things that can go into a fragment, consisting of syntax-full and/or strictly syntax-free components
fragment_part_expressions = [
attention,
cross_attention_substitute,
parenthesized_fragment,
quoted_fragment,
non_syntax_word
]
# a fragment that is permitted to contain commas
fragment_including_commas = pp.ZeroOrMore(pp.MatchFirst(
fragment_part_expressions + [
pp.Literal(',').set_parse_action(lambda x: Fragment(x[0]))
]
))
# a fragment that is not permitted to contain commas
fragment_excluding_commas = pp.ZeroOrMore(pp.MatchFirst(
fragment_part_expressions
))
# a fragment in double quotes (may be nested)
quoted_fragment << pp.QuotedString(quote_char='"', esc_char=None, esc_quote='\\"')
quoted_fragment.set_parse_action(lambda x: parse_fragment_str(x, fragment_including_commas, in_quotes=True))
# a fragment inside parentheses (may be nested)
parenthesized_fragment << (lparen + fragment_including_commas + rparen)
parenthesized_fragment.set_name('parenthesized_fragment')
parenthesized_fragment.set_debug(False)
# a string of the form (<keyword>=<float|keyword> | <float> | <keyword>) where keyword is alphanumeric + '_'
option = pp.Group(pp.MatchFirst([
keyword + equals + (number | keyword), # option=value
number.copy().set_parse_action(pp.token_map(str)), # weight
keyword # flag
]))
# options for an operator, eg "s_start=0.1, 0.3, no_normalize"
options = pp.Dict(pp.Optional(pp.delimited_list(option)))
options.set_name('options')
options.set_debug(False)
# a fragment which can be used as the target for an operator - either quoted or in parentheses, or a bare vanilla word
potential_operator_target = (quoted_fragment | parenthesized_fragment | non_syntax_word)
# a fragment whose weight has been increased or decreased by a given amount
attention_weight_operator = pp.Word('+') | pp.Word('-') | number
attention_explicit = (
pp.Group(potential_operator_target)
+ pp.Literal('.attend')
+ lparen
+ pp.Group(attention_weight_operator)
+ rparen
)
attention_explicit.set_parse_action(make_operator_object)
attention_implicit = (
pp.Group(potential_operator_target)
+ pp.NotAny(pp.White()) # do not permit whitespace between term and operator
+ pp.Group(attention_weight_operator)
)
attention_implicit.set_parse_action(lambda x: make_operator_object([x[0], '.attend', x[1]]))
attention << (attention_explicit | attention_implicit)
attention.set_name('attention')
attention.set_debug(False)
# cross-attention control by swapping one fragment for another
cross_attention_substitute << (
pp.Group(potential_operator_target).set_name('ca-target').set_debug(False)
+ pp.Literal(".swap").set_name('ca-operator').set_debug(False)
+ lparen
+ pp.Group(fragment_excluding_commas).set_name('ca-replacement').set_debug(False)
+ pp.Optional(comma + options).set_name('ca-options').set_debug(False)
+ rparen
)
cross_attention_substitute.set_name('cross_attention_substitute')
cross_attention_substitute.set_debug(False)
cross_attention_substitute.set_parse_action(make_operator_object)
# an entire self-contained prompt, which can be used in a Blend or Conjunction
prompt = pp.ZeroOrMore(pp.MatchFirst([
cross_attention_substitute,
attention,
quoted_fragment,
parenthesized_fragment,
free_word,
pp.White().suppress()
]))
quoted_prompt = quoted_fragment.copy().set_parse_action(lambda x: parse_fragment_str(x, prompt, in_quotes=True))
# a blend/lerp between the feature vectors for two or more prompts
blend = (
lparen
+ pp.Group(pp.delimited_list(pp.Group(potential_operator_target | quoted_prompt), min=1)).set_name('bl-target').set_debug(False)
+ rparen
+ pp.Literal(".blend").set_name('bl-operator').set_debug(False)
+ lparen
+ pp.Group(options).set_name('bl-options').set_debug(False)
+ rparen
)
blend.set_name('blend')
blend.set_debug(False)
blend.set_parse_action(make_operator_object)
# an operator to direct stable diffusion to step multiple times, once for each target, and then add the results together with different weights
explicit_conjunction = (
lparen
+ pp.Group(pp.delimited_list(pp.Group(potential_operator_target | quoted_prompt), min=1)).set_name('cj-target').set_debug(False)
+ rparen
+ pp.one_of([".and", ".add"]).set_name('cj-operator').set_debug(False)
+ lparen
+ pp.Group(options).set_name('cj-options').set_debug(False)
+ rparen
)
explicit_conjunction.set_name('explicit_conjunction')
explicit_conjunction.set_debug(False)
explicit_conjunction.set_parse_action(make_operator_object)
# by default a prompt consists of a Conjunction with a single term
implicit_conjunction = (blend | pp.Group(prompt)) + pp.StringEnd()
implicit_conjunction.set_parse_action(lambda x: Conjunction(x))
conjunction = (explicit_conjunction | implicit_conjunction)
return conjunction, prompt
def split_weighted_subprompts(text, skip_normalize=False)->list:
"""
Legacy blend parsing.
grabs all text up to the first occurrence of ':'
uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
if ':' has no value defined, defaults to 1.0
repeats until no text remaining
"""
prompt_parser = re.compile("""
(?P<prompt> # capture group for 'prompt'
(?:\\\:|[^:])+ # match one or more non ':' characters or escaped colons '\:'
) # end 'prompt'
(?: # non-capture group
:+ # match one or more ':' characters
(?P<weight> # capture group for 'weight'
-?\d+(?:\.\d+)? # match positive or negative integer or decimal number
)? # end weight capture group, make optional
\s* # strip spaces after weight
| # OR
$ # else, if no ':' then match end of line
) # end non-capture group
""", re.VERBOSE)
parsed_prompts = [(match.group("prompt").replace("\\:", ":"), float(
match.group("weight") or 1)) for match in re.finditer(prompt_parser, text)]
if skip_normalize:
return parsed_prompts
weight_sum = sum(map(lambda x: x[1], parsed_prompts))
if weight_sum == 0:
print(
"* Warning: Subprompt weights add up to zero. Discarding and using even weights instead.")
equal_weight = 1 / max(len(parsed_prompts), 1)
return [(x[0], equal_weight) for x in parsed_prompts]
return [(x[0], x[1] / weight_sum) for x in parsed_prompts]

View File

@ -453,7 +453,7 @@ def main():
'** Not enough window space for the interface. Please make your window larger and try again.'
)
else:
print(f"** A layout error has occurred: {str(e)}")
print(f"** An error has occurred: {str(e)}")
sys.exit(-1)

View File

@ -430,7 +430,7 @@ class TextualInversionDataset(Dataset):
placeholder_token="*",
center_crop=False,
):
self.data_root = data_root
self.data_root = Path(data_root)
self.tokenizer = tokenizer
self.learnable_property = learnable_property
self.size = size
@ -439,9 +439,9 @@ class TextualInversionDataset(Dataset):
self.flip_p = flip_p
self.image_paths = [
os.path.join(self.data_root, file_path)
for file_path in os.listdir(self.data_root)
if os.path.isfile(file_path) and file_path.endswith(('.png','.PNG','.jpg','.JPG','.jpeg','.JPEG','.gif','.GIF'))
self.data_root / file_path
for file_path in self.data_root.iterdir()
if file_path.is_file() and file_path.name.endswith(('.png','.PNG','.jpg','.JPG','.jpeg','.JPEG','.gif','.GIF'))
]
self.num_images = len(self.image_paths)

View File

@ -1,3 +1,8 @@
# adapted from bloc97's CrossAttentionControl colab
# https://github.com/bloc97/CrossAttentionControl
import enum
import math
from typing import Optional, Callable
@ -6,35 +11,13 @@ import psutil
import torch
import diffusers
from torch import nn
from compel.cross_attention_control import Arguments
from diffusers.models.unet_2d_condition import UNet2DConditionModel
from diffusers.models.cross_attention import AttnProcessor
from ldm.invoke.devices import torch_dtype
# adapted from bloc97's CrossAttentionControl colab
# https://github.com/bloc97/CrossAttentionControl
class Arguments:
def __init__(self, edited_conditioning: torch.Tensor, edit_opcodes: list[tuple], edit_options: dict):
"""
:param edited_conditioning: if doing cross-attention control, the edited conditioning [1 x 77 x 768]
:param edit_opcodes: if doing cross-attention control, a list of difflib.SequenceMatcher-like opcodes describing how to map original conditioning tokens to edited conditioning tokens (only the 'equal' opcode is required)
:param edit_options: if doing cross-attention control, per-edit options. there should be 1 item in edit_options for each item in edit_opcodes.
"""
# todo: rewrite this to take embedding fragments rather than a single edited_conditioning vector
self.edited_conditioning = edited_conditioning
self.edit_opcodes = edit_opcodes
if edited_conditioning is not None:
assert len(edit_opcodes) == len(edit_options), \
"there must be 1 edit_options dict for each edit_opcodes tuple"
non_none_edit_options = [x for x in edit_options if x is not None]
assert len(non_none_edit_options)>0, "missing edit_options"
if len(non_none_edit_options)>1:
print('warning: cross-attention control options are not working properly for >1 edit')
self.edit_options = non_none_edit_options[0]
class CrossAttentionType(enum.Enum):
SELF = 1
TOKENS = 2
@ -319,7 +302,6 @@ def override_cross_attention(model, context: Context, is_running_diffusers = Fal
Inject attention parameters and functions into the passed in model to enable cross attention editing.
:param model: The unet model to inject into.
:param cross_attention_control_args: Arugments passeed to the CrossAttentionControl implementations
:return: None
"""
@ -523,7 +505,7 @@ from dataclasses import field, dataclass
import torch
from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor, SlicedAttnProcessor, AttnProcessor
from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor, SlicedAttnProcessor
@dataclass

View File

@ -1,236 +0,0 @@
import math
import torch
from transformers import CLIPTokenizer, CLIPTextModel
from ldm.invoke.devices import torch_dtype
from ldm.modules.textual_inversion_manager import TextualInversionManager
class WeightedPromptFragmentsToEmbeddingsConverter():
def __init__(self,
tokenizer: CLIPTokenizer, # converts strings to lists of int token ids
text_encoder: CLIPTextModel, # convert a list of int token ids to a tensor of embeddings
textual_inversion_manager: TextualInversionManager = None
):
self.tokenizer = tokenizer
self.text_encoder = text_encoder
self.textual_inversion_manager = textual_inversion_manager
@property
def max_length(self):
return self.tokenizer.model_max_length
def get_embeddings_for_weighted_prompt_fragments(self,
text: list[list[str]],
fragment_weights: list[list[float]],
should_return_tokens: bool = False,
device='cpu'
) -> torch.Tensor:
'''
:param text: A list of fragments of text to which different weights are to be applied.
:param fragment_weights: A batch of lists of weights, one for each entry in `fragments`.
:return: A tensor of shape `[1, 77, token_dim]` containing weighted embeddings where token_dim is 768 for SD1
and 1280 for SD2
'''
if len(text) != len(fragment_weights):
raise ValueError(f"lengths of text and fragment_weights lists are not the same ({len(text)} != {len(fragment_weights)})")
batch_z = None
batch_tokens = None
for fragments, weights in zip(text, fragment_weights):
# First, weight tokens in individual fragments by scaling the feature vectors as requested (effectively
# applying a multiplier to the CFG scale on a per-token basis).
# For tokens weighted<1, intuitively we want SD to become not merely *less* interested in the concept
# captured by the fragment but actually *dis*interested in it (a 0.01 interest in "red" is still an active
# interest, however small, in redness; what the user probably intends when they attach the number 0.01 to
# "red" is to tell SD that it should almost completely *ignore* redness).
# To do this, the embedding is lerped away from base_embedding in the direction of an embedding for a prompt
# string from which the low-weighted fragment has been simply removed. The closer the weight is to zero, the
# closer the resulting embedding is to an embedding for a prompt that simply lacks this fragment.
# handle weights >=1
tokens, per_token_weights = self.get_token_ids_and_expand_weights(fragments, weights, device=device)
base_embedding = self.build_weighted_embedding_tensor(tokens, per_token_weights)
# this is our starting point
embeddings = base_embedding.unsqueeze(0)
per_embedding_weights = [1.0]
# now handle weights <1
# Do this by building extra embeddings tensors that lack the words being <1 weighted. These will be lerped
# with the embeddings tensors that have the words, such that if the weight of a word is 0.5, the resulting
# embedding will be exactly half-way between the unweighted prompt and the prompt with the <1 weighted words
# removed.
# eg for "mountain:1 man:0.5", intuitively the "man" should be "half-gone". therefore, append an embedding
# for "mountain" (i.e. without "man") to the already-produced embedding for "mountain man", and weight it
# such that the resulting lerped embedding is exactly half-way between "mountain man" and "mountain".
for index, fragment_weight in enumerate(weights):
if fragment_weight < 1:
fragments_without_this = fragments[:index] + fragments[index+1:]
weights_without_this = weights[:index] + weights[index+1:]
tokens, per_token_weights = self.get_token_ids_and_expand_weights(fragments_without_this, weights_without_this, device=device)
embedding_without_this = self.build_weighted_embedding_tensor(tokens, per_token_weights)
embeddings = torch.cat((embeddings, embedding_without_this.unsqueeze(0)), dim=1)
# weight of the embedding *without* this fragment gets *stronger* as its weight approaches 0
# if fragment_weight = 0, basically we want embedding_without_this to completely overwhelm base_embedding
# therefore:
# fragment_weight = 1: we are at base_z => lerp weight 0
# fragment_weight = 0.5: we are halfway between base_z and here => lerp weight 1
# fragment_weight = 0: we're now entirely overriding base_z ==> lerp weight inf
# so let's use tan(), because:
# tan is 0.0 at 0,
# 1.0 at PI/4, and
# inf at PI/2
# -> tan((1-weight)*PI/2) should give us ideal lerp weights
epsilon = 1e-9
fragment_weight = max(epsilon, fragment_weight) # inf is bad
embedding_lerp_weight = math.tan((1.0 - fragment_weight) * math.pi / 2)
# todo handle negative weight?
per_embedding_weights.append(embedding_lerp_weight)
lerped_embeddings = self.apply_embedding_weights(embeddings, per_embedding_weights, normalize=True).squeeze(0)
#print(f"assembled tokens for '{fragments}' into tensor of shape {lerped_embeddings.shape}")
# append to batch
batch_z = lerped_embeddings.unsqueeze(0) if batch_z is None else torch.cat([batch_z, lerped_embeddings.unsqueeze(0)], dim=1)
batch_tokens = tokens.unsqueeze(0) if batch_tokens is None else torch.cat([batch_tokens, tokens.unsqueeze(0)], dim=1)
# should have shape (B, 77, 768)
#print(f"assembled all tokens into tensor of shape {batch_z.shape}")
if should_return_tokens:
return batch_z, batch_tokens
else:
return batch_z
def get_token_ids(self, fragments: list[str], include_start_and_end_markers: bool = True) -> list[list[int]]:
"""
Convert a list of strings like `["a cat", "sitting", "on a mat"]` into a list of lists of token ids like
`[[bos, 0, 1, eos], [bos, 2, eos], [bos, 3, 0, 4, eos]]`. bos/eos markers are skipped if
`include_start_and_end_markers` is `False`. Each list will be restricted to the maximum permitted length
(typically 75 tokens + eos/bos markers).
:param fragments: The strings to convert.
:param include_start_and_end_markers:
:return:
"""
# for args documentation see ENCODE_KWARGS_DOCSTRING in tokenization_utils_base.py (in `transformers` lib)
token_ids_list = self.tokenizer(
fragments,
truncation=True,
max_length=self.max_length,
return_overflowing_tokens=False,
padding='do_not_pad',
return_tensors=None, # just give me lists of ints
)['input_ids']
result = []
for token_ids in token_ids_list:
# trim eos/bos
token_ids = token_ids[1:-1]
# pad for textual inversions with vector length >1
token_ids = self.textual_inversion_manager.expand_textual_inversion_token_ids_if_necessary(token_ids)
# restrict length to max_length-2 (leaving room for bos/eos)
token_ids = token_ids[0:self.max_length - 2]
# add back eos/bos if requested
if include_start_and_end_markers:
token_ids = [self.tokenizer.bos_token_id] + token_ids + [self.tokenizer.eos_token_id]
result.append(token_ids)
return result
@classmethod
def apply_embedding_weights(self, embeddings: torch.Tensor, per_embedding_weights: list[float], normalize:bool) -> torch.Tensor:
per_embedding_weights = torch.tensor(per_embedding_weights, dtype=embeddings.dtype, device=embeddings.device)
if normalize:
per_embedding_weights = per_embedding_weights / torch.sum(per_embedding_weights)
reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1, 1,))
#reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1,1,)).expand(embeddings.shape)
return torch.sum(embeddings * reshaped_weights, dim=1)
# lerped embeddings has shape (77, 768)
def get_token_ids_and_expand_weights(self, fragments: list[str], weights: list[float], device: str) -> (torch.Tensor, torch.Tensor):
'''
Given a list of text fragments and corresponding weights: tokenize each fragment, append the token sequences
together and return a padded token sequence starting with the bos marker, ending with the eos marker, and padded
or truncated as appropriate to `self.max_length`. Also return a list of weights expanded from the passed-in
weights to match each token.
:param fragments: Text fragments to tokenize and concatenate. May be empty.
:param weights: Per-fragment weights (i.e. quasi-CFG scaling). Values from 0 to inf are permitted. In practise with SD1.5
values >1.6 tend to produce garbage output. Must have same length as `fragment`.
:return: A tuple of tensors `(token_ids, weights)`. `token_ids` is ints, `weights` is floats, both have shape `[self.max_length]`.
'''
if len(fragments) != len(weights):
raise ValueError(f"lengths of text and fragment_weights lists are not the same ({len(fragments)} != {len(weights)})")
# empty is meaningful
if len(fragments) == 0:
fragments = ['']
weights = [1.0]
per_fragment_token_ids = self.get_token_ids(fragments, include_start_and_end_markers=False)
all_token_ids = []
per_token_weights = []
#print("all fragments:", fragments, weights)
for this_fragment_token_ids, weight in zip(per_fragment_token_ids, weights):
# append
all_token_ids += this_fragment_token_ids
# fill out weights tensor with one float per token
per_token_weights += [float(weight)] * len(this_fragment_token_ids)
# leave room for bos/eos
max_token_count_without_bos_eos_markers = self.max_length - 2
if len(all_token_ids) > max_token_count_without_bos_eos_markers:
excess_token_count = len(all_token_ids) - max_token_count_without_bos_eos_markers
# TODO build nice description string of how the truncation was applied
# this should be done by calling self.tokenizer.convert_ids_to_tokens() then passing the result to
# self.tokenizer.convert_tokens_to_string() for the token_ids on each side of the truncation limit.
print(f">> Prompt is {excess_token_count} token(s) too long and has been truncated")
all_token_ids = all_token_ids[0:max_token_count_without_bos_eos_markers]
per_token_weights = per_token_weights[0:max_token_count_without_bos_eos_markers]
# pad out to a self.max_length-entry array: [bos_token, <prompt tokens>, eos_token, pad_token…]
# (typically self.max_length == 77)
all_token_ids = [self.tokenizer.bos_token_id] + all_token_ids + [self.tokenizer.eos_token_id]
per_token_weights = [1.0] + per_token_weights + [1.0]
pad_length = self.max_length - len(all_token_ids)
all_token_ids += [self.tokenizer.pad_token_id] * pad_length
per_token_weights += [1.0] * pad_length
all_token_ids_tensor = torch.tensor(all_token_ids, dtype=torch.long, device=device)
per_token_weights_tensor = torch.tensor(per_token_weights, dtype=torch_dtype(self.text_encoder.device), device=device)
#print(f"assembled all_token_ids_tensor with shape {all_token_ids_tensor.shape}")
return all_token_ids_tensor, per_token_weights_tensor
def build_weighted_embedding_tensor(self, token_ids: torch.Tensor, per_token_weights: torch.Tensor) -> torch.Tensor:
'''
Build a tensor that embeds the passed-in token IDs and applies the given per_token weights
:param token_ids: A tensor of shape `[self.max_length]` containing token IDs (ints)
:param per_token_weights: A tensor of shape `[self.max_length]` containing weights (floats)
:return: A tensor of shape `[1, self.max_length, token_dim]` representing the requested weighted embeddings
where `token_dim` is 768 for SD1 and 1280 for SD2.
'''
#print(f"building weighted embedding tensor for {tokens} with weights {per_token_weights}")
if token_ids.shape != torch.Size([self.max_length]):
raise ValueError(f"token_ids has shape {token_ids.shape} - expected [{self.max_length}]")
z = self.text_encoder(token_ids.unsqueeze(0), return_dict=False)[0]
empty_token_ids = torch.tensor([self.tokenizer.bos_token_id] +
[self.tokenizer.pad_token_id] * (self.max_length-2) +
[self.tokenizer.eos_token_id], dtype=torch.int, device=z.device).unsqueeze(0)
empty_z = self.text_encoder(empty_token_ids).last_hidden_state
batch_weights_expanded = per_token_weights.reshape(per_token_weights.shape + (1,)).expand(z.shape).to(z)
z_delta_from_empty = z - empty_z
weighted_z = empty_z + (z_delta_from_empty * batch_weights_expanded)
return weighted_z

View File

@ -8,6 +8,7 @@ import torch
from picklescan.scanner import scan_file_path
from transformers import CLIPTextModel, CLIPTokenizer
from compel.embeddings_provider import BaseTextualInversionManager
from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary
@ -23,7 +24,7 @@ class TextualInversion:
return self.embedding.shape[0]
class TextualInversionManager:
class TextualInversionManager(BaseTextualInversionManager):
def __init__(
self,
tokenizer: CLIPTokenizer,
@ -34,6 +35,7 @@ class TextualInversionManager:
self.text_encoder = text_encoder
self.full_precision = full_precision
self.hf_concepts_library = HuggingFaceConceptsLibrary()
self.trigger_to_sourcefile = dict()
default_textual_inversions: list[TextualInversion] = []
self.textual_inversions = default_textual_inversions
@ -59,15 +61,17 @@ class TextualInversionManager:
def get_all_trigger_strings(self) -> list[str]:
return [ti.trigger_string for ti in self.textual_inversions]
def load_textual_inversion(self, ckpt_path: Union[str,Path], defer_injecting_tokens: bool = False):
def load_textual_inversion(
self, ckpt_path: Union[str, Path], defer_injecting_tokens: bool = False
):
ckpt_path = Path(ckpt_path)
if not ckpt_path.is_file():
return
if str(ckpt_path).endswith(".DS_Store"):
return
try:
scan_result = scan_file_path(str(ckpt_path))
if scan_result.infected_files == 1:
@ -89,31 +93,49 @@ class TextualInversionManager:
return
elif (
self.text_encoder.get_input_embeddings().weight.data[0].shape[0]
!= embedding_info['token_dim']
!= embedding_info["token_dim"]
):
print(
f"** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info['token_dim']}."
)
return
if embedding_info:
try:
self._add_textual_inversion(
embedding_info["name"],
embedding_info["embedding"],
defer_injecting_tokens=defer_injecting_tokens,
)
except ValueError as e:
print(f' | Ignoring incompatible embedding {embedding_info["name"]}')
print(f" | The error was {str(e)}")
else:
print(
f">> Failed to load embedding located at {str(ckpt_path)}. Unsupported file."
# Resolve the situation in which an earlier embedding has claimed the same
# trigger string. We replace the trigger with '<source_file>', as we used to.
trigger_str = embedding_info["name"]
sourcefile = (
f"{ckpt_path.parent.name}/{ckpt_path.name}"
if ckpt_path.name == "learned_embeds.bin"
else ckpt_path.name
)
if trigger_str in self.trigger_to_sourcefile:
replacement_trigger_str = (
f"<{ckpt_path.parent.name}>"
if ckpt_path.name == "learned_embeds.bin"
else f"<{ckpt_path.stem}>"
)
print(
f">> {sourcefile}: Trigger token '{trigger_str}' is already claimed by '{self.trigger_to_sourcefile[trigger_str]}'. Trigger this concept with {replacement_trigger_str}"
)
trigger_str = replacement_trigger_str
try:
self._add_textual_inversion(
trigger_str,
embedding_info["embedding"],
defer_injecting_tokens=defer_injecting_tokens,
)
# remember which source file claims this trigger
self.trigger_to_sourcefile[trigger_str] = sourcefile
except ValueError as e:
print(f' | Ignoring incompatible embedding {embedding_info["name"]}')
print(f" | The error was {str(e)}")
def _add_textual_inversion(
self, trigger_str, embedding, defer_injecting_tokens=False
) -> TextualInversion:
) -> Optional[TextualInversion]:
"""
Add a textual inversion to be recognised.
:param trigger_str: The trigger text in the prompt that activates this textual inversion. If unknown to the embedder's tokenizer, will be added.
@ -122,7 +144,7 @@ class TextualInversionManager:
"""
if trigger_str in [ti.trigger_string for ti in self.textual_inversions]:
print(
f">> TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'"
f"** TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'"
)
return
if not self.full_precision:
@ -131,7 +153,7 @@ class TextualInversionManager:
embedding = embedding.unsqueeze(0)
elif len(embedding.shape) > 2:
raise ValueError(
f"TextualInversionManager cannot add {trigger_str} because the embedding shape {embedding.shape} is incorrect. The embedding must have shape [token_dim] or [V, token_dim] where V is vector length and token_dim is 768 for SD1 or 1280 for SD2."
f"** TextualInversionManager cannot add {trigger_str} because the embedding shape {embedding.shape} is incorrect. The embedding must have shape [token_dim] or [V, token_dim] where V is vector length and token_dim is 768 for SD1 or 1280 for SD2."
)
try:
@ -147,7 +169,7 @@ class TextualInversionManager:
else:
traceback.print_exc()
print(
f">> TextualInversionManager was unable to add a textual inversion with trigger string {trigger_str}."
f"** TextualInversionManager was unable to add a textual inversion with trigger string {trigger_str}."
)
raise
@ -294,7 +316,7 @@ class TextualInversionManager:
elif file_type == "bin":
return self._parse_embedding_bin(embedding_file)
else:
print(f">> Not a recognized embedding file: {embedding_file}")
print(f"** Notice: unrecognized embedding file format: {embedding_file}")
return None
def _parse_embedding_pt(self, embedding_file):
@ -355,8 +377,9 @@ class TextualInversionManager:
embedding_info = None
else:
for token in list(embedding_ckpt.keys()):
embedding_info["name"] = token or os.path.basename(
os.path.splitext(embedding_file)[0]
embedding_info["name"] = (
token
or f"<{os.path.basename(os.path.splitext(embedding_file)[0])}>"
)
embedding_info["embedding"] = embedding_ckpt[token]
embedding_info[
@ -380,7 +403,7 @@ class TextualInversionManager:
embedding_info["name"] = (
token
if token != "*"
else os.path.basename(os.path.splitext(embedding_file)[0])
else f"<{os.path.basename(os.path.splitext(embedding_file)[0])}>"
)
embedding_info["embedding"] = embedding_ckpt[
"string_to_param"

View File

@ -38,8 +38,9 @@ dependencies = [
"albumentations",
"click",
"clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
"compel>=0.1.6",
"datasets",
"diffusers[torch]~=0.11",
"diffusers[torch]~=0.13",
"dnspython==2.2.1",
"einops",
"eventlet",

View File

@ -1,499 +0,0 @@
import unittest
import pyparsing
from ldm.invoke.prompt_parser import PromptParser, Blend, Conjunction, FlattenedPrompt, CrossAttentionControlSubstitute, \
Fragment
def parse_prompt(prompt_string):
pp = PromptParser()
#print(f"parsing '{prompt_string}'")
parse_result = pp.parse_conjunction(prompt_string)
#print(f"-> parsed '{prompt_string}' to {parse_result}")
return parse_result
def make_basic_conjunction(strings: list[str]):
fragments = [Fragment(x) for x in strings]
return Conjunction([FlattenedPrompt(fragments)])
def make_weighted_conjunction(weighted_strings: list[tuple[str,float]]):
fragments = [Fragment(x, w) for x,w in weighted_strings]
return Conjunction([FlattenedPrompt(fragments)])
class PromptParserTestCase(unittest.TestCase):
def test_empty(self):
self.assertEqual(make_weighted_conjunction([('', 1)]), parse_prompt(''))
def test_basic(self):
self.assertEqual(make_weighted_conjunction([("fire flames", 1)]), parse_prompt("fire flames"))
self.assertEqual(make_weighted_conjunction([('fire flames', 1)]), parse_prompt("fire (flames)"))
self.assertEqual(make_weighted_conjunction([("fire, flames", 1)]), parse_prompt("fire, flames"))
self.assertEqual(make_weighted_conjunction([("fire, flames , fire", 1)]), parse_prompt("fire, flames , fire"))
self.assertEqual(make_weighted_conjunction([("cat hot-dog eating", 1)]), parse_prompt("cat hot-dog eating"))
self.assertEqual(make_basic_conjunction(['Dalí']), parse_prompt("Dalí"))
def test_attention(self):
self.assertEqual(make_weighted_conjunction([('flames', 0.5)]), parse_prompt("(flames)0.5"))
self.assertEqual(make_weighted_conjunction([('flames', 0.5)]), parse_prompt("(flames).attend(0.5)"))
self.assertEqual(make_weighted_conjunction([('flames', 0.5)]), parse_prompt("flames.attend(0.5)"))
self.assertEqual(make_weighted_conjunction([('flames', 0.5)]), parse_prompt("\"flames\".attend(0.5)"))
self.assertEqual(make_weighted_conjunction([('fire flames', 0.5)]), parse_prompt("(fire flames)0.5"))
self.assertEqual(make_weighted_conjunction([('fire flames', 0.5)]), parse_prompt("(fire flames).attend(0.5)"))
self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("(flames)+"))
self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("flames+"))
self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("\"flames\"+"))
self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("flames.attend(+)"))
self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("(flames).attend(+)"))
self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("\"flames\".attend(+)"))
self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("(flames)-"))
self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("flames-"))
self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("\"flames\"-"))
self.assertEqual(make_weighted_conjunction([('fire', 1), ('flames', 0.5)]), parse_prompt("fire (flames)0.5"))
self.assertEqual(make_weighted_conjunction([('fire', 1), ('flames', 0.5)]), parse_prompt("fire flames.attend(0.5)"))
self.assertEqual(make_weighted_conjunction([('fire', 1), ('flames', 0.5)]), parse_prompt("fire (flames).attend(0.5)"))
self.assertEqual(make_weighted_conjunction([('fire', 1), ('flames', 0.5)]), parse_prompt("fire \"flames\".attend(0.5)"))
self.assertEqual(make_weighted_conjunction([('flames', pow(1.1, 2))]), parse_prompt("(flames)++"))
self.assertEqual(make_weighted_conjunction([('flames', pow(0.9, 2))]), parse_prompt("(flames)--"))
self.assertEqual(make_weighted_conjunction([('flowers', pow(0.9, 3)), ('flames', pow(1.1, 3))]), parse_prompt("(flowers)--- flames+++"))
self.assertEqual(make_weighted_conjunction([('pretty flowers', 1.1)]),
parse_prompt("(pretty flowers)+"))
self.assertEqual(make_weighted_conjunction([('pretty flowers', 1.1), (', the flames are too hot', 1)]),
parse_prompt("(pretty flowers)+, the flames are too hot"))
def test_no_parens_attention_runon(self):
self.assertEqual(make_weighted_conjunction([('fire', 1.0), ('flames', pow(1.1, 2))]), parse_prompt("fire flames++"))
self.assertEqual(make_weighted_conjunction([('fire', 1.0), ('flames', pow(0.9, 2))]), parse_prompt("fire flames--"))
self.assertEqual(make_weighted_conjunction([('flowers', 1.0), ('fire', pow(1.1, 2)), ('flames', 1.0)]), parse_prompt("flowers fire++ flames"))
self.assertEqual(make_weighted_conjunction([('flowers', 1.0), ('fire', pow(0.9, 2)), ('flames', 1.0)]), parse_prompt("flowers fire-- flames"))
def test_explicit_conjunction(self):
self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])]), parse_prompt('("fire", "flames").and(1,1)'))
self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])]), parse_prompt('("fire", "flames").and()'))
self.assertEqual(
Conjunction([FlattenedPrompt([('fire flames', 1.0)]), FlattenedPrompt([('mountain man', 1.0)])]), parse_prompt('("fire flames", "mountain man").and()'))
self.assertEqual(Conjunction([FlattenedPrompt([('fire', 2.0)]), FlattenedPrompt([('flames', 0.9)])]), parse_prompt('("(fire)2.0", "flames-").and()'))
self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)]),
FlattenedPrompt([('mountain man', 1.0)])]), parse_prompt('("fire", "flames", "mountain man").and()'))
def test_conjunction_weights(self):
self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])], weights=[2.0,1.0]), parse_prompt('("fire", "flames").and(2,1)'))
self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])], weights=[1.0,2.0]), parse_prompt('("fire", "flames").and(1,2)'))
with self.assertRaises(PromptParser.ParsingException):
parse_prompt('("fire", "flames").and(2)')
parse_prompt('("fire", "flames").and(2,1,2)')
def test_complex_conjunction(self):
#print(parse_prompt("a person with a hat (riding a bicycle.swap(skateboard))++"))
self.assertEqual(Conjunction([FlattenedPrompt([("mountain man", 1.0)]), FlattenedPrompt([("a person with a hat", 1.0), ("riding a bicycle", pow(1.1,2))])], weights=[0.5, 0.5]),
parse_prompt("(\"mountain man\", \"a person with a hat (riding a bicycle)++\").and(0.5, 0.5)"))
self.assertEqual(Conjunction([FlattenedPrompt([("mountain man", 1.0)]),
FlattenedPrompt([("a person with a hat", 1.0),
("riding a", 1.1*1.1),
CrossAttentionControlSubstitute(
[Fragment("bicycle", pow(1.1,2))],
[Fragment("skateboard", pow(1.1,2))])
])
], weights=[0.5, 0.5]),
parse_prompt("(\"mountain man\", \"a person with a hat (riding a bicycle.swap(skateboard))++\").and(0.5, 0.5)"))
def test_badly_formed(self):
def make_untouched_prompt(prompt):
return Conjunction([FlattenedPrompt([(prompt, 1.0)])])
def assert_if_prompt_string_not_untouched(prompt):
self.assertEqual(make_untouched_prompt(prompt), parse_prompt(prompt))
assert_if_prompt_string_not_untouched('a test prompt')
assert_if_prompt_string_not_untouched('a badly formed +test prompt')
assert_if_prompt_string_not_untouched('a badly (formed test prompt')
#with self.assertRaises(pyparsing.ParseException):
assert_if_prompt_string_not_untouched('a badly (formed +test prompt')
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a badly formed +test prompt',1)])]) , parse_prompt('a badly (formed +test )prompt'))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('(((a badly formed +test prompt',1)])]) , parse_prompt('(((a badly (formed +test )prompt'))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('(a ba dly f ormed +test prompt',1)])]) , parse_prompt('(a (ba)dly (f)ormed +test prompt'))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('(a ba dly f ormed +test +prompt',1)])]) , parse_prompt('(a (ba)dly (f)ormed +test +prompt'))
self.assertEqual(Conjunction([Blend([FlattenedPrompt([Fragment('((a badly (formed +test', 1)])], [1.0])]),
parse_prompt('("((a badly (formed +test ").blend(1.0)'))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('hamburger bun', 1)])]),
parse_prompt("hamburger ((bun))"))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('hamburger bun', 1)])]),
parse_prompt("hamburger (bun)"))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('hamburger kaiser roll', 1)])]),
parse_prompt("hamburger (kaiser roll)"))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('hamburger kaiser roll', 1)])]),
parse_prompt("hamburger ((kaiser roll))"))
def test_blend(self):
self.assertEqual(Conjunction(
[Blend([FlattenedPrompt([('mountain', 1.0)]), FlattenedPrompt([('man', 1.0)])], [1.0, 1.0])]),
parse_prompt("(\"mountain\", \"man\").blend()")
)
self.assertEqual(Conjunction(
[Blend([FlattenedPrompt([('mountain', 1.0)]), FlattenedPrompt([('man', 1.0)])], [1.0, 1.0])]),
parse_prompt("(mountain, man).blend()")
)
self.assertEqual(Conjunction(
[Blend([FlattenedPrompt([('mountain', 1.0)]), FlattenedPrompt([('man', 1.0)])], [1.0, 1.0])]),
parse_prompt("((mountain), (man)).blend()")
)
self.assertEqual(Conjunction(
[Blend([FlattenedPrompt([('mountain', 1.0)]), FlattenedPrompt([('tall man', 1.0)])], [1.0, 1.0])]),
parse_prompt("((mountain), (tall man)).blend()")
)
with self.assertRaises(PromptParser.ParsingException):
print(parse_prompt("((mountain), \"cat.swap(dog)\").blend()"))
self.assertEqual(Conjunction(
[Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('fire flames', 1.0)])], [0.7, 0.3])]),
parse_prompt("(\"fire\", \"fire flames\").blend(0.7, 0.3)")
)
self.assertEqual(Conjunction([Blend(
[FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('fire flames', 1.0)]), FlattenedPrompt([('hi', 1.0)])],
[0.7, 0.3, 1.0])]),
parse_prompt("(\"fire\", \"fire flames\", \"hi\").blend(0.7, 0.3, 1.0)")
)
self.assertEqual(Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]),
FlattenedPrompt([('fire flames', 1.0), ('hot', pow(1.1, 2))]),
FlattenedPrompt([('hi', 1.0)])],
weights=[0.7, 0.3, 1.0])]),
parse_prompt("(\"fire\", \"fire flames (hot)++\", \"hi\").blend(0.7, 0.3, 1.0)")
)
# blend a single entry is not a failure
self.assertEqual(Conjunction([Blend([FlattenedPrompt([('fire', 1.0)])], [0.7])]),
parse_prompt("(\"fire\").blend(0.7)")
)
# blend with empty
self.assertEqual(
Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('', 1.0)])], [0.7, 1.0])]),
parse_prompt("(\"fire\", \"\").blend(0.7, 1)")
)
self.assertEqual(
Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('', 1.0)])], [0.7, 1.0])]),
parse_prompt("(\"fire\", \" \").blend(0.7, 1)")
)
self.assertEqual(
Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('', 1.0)])], [0.7, 1.0])]),
parse_prompt("(\"fire\", \" \").blend(0.7, 1)")
)
self.assertEqual(
Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([(',', 1.0)])], [0.7, 1.0])]),
parse_prompt("(\"fire\", \" , \").blend(0.7, 1)")
)
self.assertEqual(
Conjunction([Blend([FlattenedPrompt([('mountain , man , hairy', 1)]),
FlattenedPrompt([('face , teeth ,', 1), ('eyes', 0.9*0.9)])], weights=[1.0,-1.0], normalize_weights=True)]),
parse_prompt('("mountain, man, hairy", "face, teeth, eyes--").blend(1,-1)')
)
self.assertEqual(
Conjunction([Blend([FlattenedPrompt([('mountain , man , hairy', 1)]),
FlattenedPrompt([('face , teeth ,', 1), ('eyes', 0.9 * 0.9)])], weights=[1.0, -1.0], normalize_weights=False)]),
parse_prompt('("mountain, man, hairy", "face, teeth, eyes--").blend(1,-1,no_normalize)')
)
with self.assertRaises(PromptParser.ParsingException):
parse_prompt("(\"fire\", \"fire flames\").blend(0.7, 0.3, 0.1)")
with self.assertRaises(PromptParser.ParsingException):
parse_prompt("(\"fire\", \"fire flames\").blend(0.7)")
def test_nested(self):
self.assertEqual(make_weighted_conjunction([('fire', 1.0), ('flames', 2.0), ('trees', 3.0)]),
parse_prompt('fire (flames (trees)1.5)2.0'))
self.assertEqual(Conjunction([Blend(prompts=[FlattenedPrompt([('fire', 1.0), ('flames', 1.2100000000000002)]),
FlattenedPrompt([('mountain', 1.0), ('man', 2.0)])],
weights=[1.0, 1.0])]),
parse_prompt('("fire (flames)++", "mountain (man)2").blend(1,1)'))
def test_cross_attention_control(self):
self.assertEqual(Conjunction([FlattenedPrompt([CrossAttentionControlSubstitute([Fragment('sun')], [Fragment('moon')])])]),
parse_prompt("sun.swap(moon)"))
self.assertEqual(Conjunction([
FlattenedPrompt([Fragment('a', 1),
CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)]),
Fragment('eating a hotdog', 1)])]), parse_prompt("a \"cat\".swap(dog) eating a hotdog"))
self.assertEqual(Conjunction([
FlattenedPrompt([Fragment('a', 1),
CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)]),
Fragment('eating a hotdog', 1)])]), parse_prompt("a cat.swap(dog) eating a hotdog"))
fire_flames_to_trees = Conjunction([FlattenedPrompt([('fire', 1.0), \
CrossAttentionControlSubstitute([Fragment('flames', 1)], [Fragment('trees', 1)])])])
self.assertEqual(fire_flames_to_trees, parse_prompt('fire "flames".swap(trees)'))
self.assertEqual(fire_flames_to_trees, parse_prompt('fire (flames).swap(trees)'))
self.assertEqual(fire_flames_to_trees, parse_prompt('fire ("flames").swap(trees)'))
self.assertEqual(fire_flames_to_trees, parse_prompt('fire "flames".swap("trees")'))
self.assertEqual(fire_flames_to_trees, parse_prompt('fire (flames).swap("trees")'))
self.assertEqual(fire_flames_to_trees, parse_prompt('fire ("flames").swap("trees")'))
fire_flames_to_trees_and_houses = Conjunction([FlattenedPrompt([('fire', 1.0), \
CrossAttentionControlSubstitute([Fragment('flames', 1)], [Fragment('trees and houses', 1)])])])
self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire ("flames").swap("trees and houses")'))
self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire (flames).swap("trees and houses")'))
self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire "flames".swap("trees and houses")'))
trees_and_houses_to_flames = Conjunction([FlattenedPrompt([('fire', 1.0), \
CrossAttentionControlSubstitute([Fragment('trees and houses', 1)], [Fragment('flames',1)])])])
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire ("trees and houses").swap("flames")'))
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire (trees and houses).swap("flames")'))
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire "trees and houses".swap("flames")'))
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire ("trees and houses").swap(flames)'))
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire (trees and houses).swap(flames)'))
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire "trees and houses".swap(flames)'))
flames_to_trees_fire = Conjunction([FlattenedPrompt([
CrossAttentionControlSubstitute([Fragment('flames',1)], [Fragment('trees',1)]),
(', fire', 1.0)])])
self.assertEqual(flames_to_trees_fire, parse_prompt('"flames".swap("trees"), fire'))
self.assertEqual(flames_to_trees_fire, parse_prompt('(flames).swap("trees"), fire'))
self.assertEqual(flames_to_trees_fire, parse_prompt('("flames").swap("trees"), fire'))
self.assertEqual(flames_to_trees_fire, parse_prompt('"flames".swap(trees), fire'))
self.assertEqual(flames_to_trees_fire, parse_prompt('(flames).swap(trees), fire '))
self.assertEqual(flames_to_trees_fire, parse_prompt('("flames").swap(trees), fire '))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
CrossAttentionControlSubstitute([Fragment('',1)], [Fragment('in winter',1)])])]),
parse_prompt('a forest landscape "".swap("in winter")'))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
CrossAttentionControlSubstitute([Fragment('',1)], [Fragment('in winter',1)])])]),
parse_prompt('a forest landscape ().swap(in winter)'))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
CrossAttentionControlSubstitute([Fragment('',1)], [Fragment('in winter',1)])])]),
parse_prompt('a forest landscape " ".swap("in winter")'))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]),
parse_prompt('a forest landscape "in winter".swap("")'))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]),
parse_prompt('a forest landscape "in winter".swap()'))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]),
parse_prompt('a forest landscape "in winter".swap(" ")'))
def test_cross_attention_control_with_attention(self):
flames_to_trees_fire = Conjunction([FlattenedPrompt([
CrossAttentionControlSubstitute([Fragment('flames',0.5)], [Fragment('trees',0.7)]),
Fragment(',', 1), Fragment('fire', 2.0)])])
self.assertEqual(flames_to_trees_fire, parse_prompt('"(flames)0.5".swap("(trees)0.7"), (fire)2.0'))
flames_to_trees_fire = Conjunction([FlattenedPrompt([
CrossAttentionControlSubstitute([Fragment('fire',0.5), Fragment('flames',0.25)], [Fragment('trees',0.7)]),
Fragment(',', 1), Fragment('fire', 2.0)])])
self.assertEqual(flames_to_trees_fire, parse_prompt('"(fire (flames)0.5)0.5".swap("(trees)0.7"), (fire)2.0'))
flames_to_trees_fire = Conjunction([FlattenedPrompt([
CrossAttentionControlSubstitute([Fragment('fire',0.5), Fragment('flames',0.25)], [Fragment('trees',0.7), Fragment('houses', 1)]),
Fragment(',', 1), Fragment('fire', 2.0)])])
self.assertEqual(flames_to_trees_fire, parse_prompt('"(fire (flames)0.5)0.5".swap("(trees)0.7 houses"), (fire)2.0'))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1),
CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]),
Fragment('eating a', 1),
CrossAttentionControlSubstitute([Fragment('hotdog',1)], [Fragment('hotdog', pow(1.1,4))])
])]),
parse_prompt("a cat.swap(dog) eating a hotdog.swap(hotdog++++)"))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1),
CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]),
Fragment('eating a', 1),
CrossAttentionControlSubstitute([Fragment('hotdog',1)], [Fragment('hotdog', pow(1.1,4))])
])]),
parse_prompt("a cat.swap(dog) eating a hotdog.swap(hotdog++++, shape_freedom=0.5)"))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1),
CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]),
Fragment('eating a', 1),
CrossAttentionControlSubstitute([Fragment('hotdog',1)], [Fragment('hotdog', pow(1.1,4))])
])]),
parse_prompt("a cat.swap(dog) eating a hotdog.swap(\"hotdog++++\", shape_freedom=0.5)"))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1),
CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]),
Fragment('eating a', 1),
CrossAttentionControlSubstitute([Fragment('hotdog',1)], [Fragment('h(o)tdog', pow(1.1,4))])
])]),
parse_prompt("a cat.swap(dog) eating a hotdog.swap(h\(o\)tdog++++, shape_freedom=0.5)"))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1),
CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]),
Fragment('eating a', 1),
CrossAttentionControlSubstitute([Fragment('hotdog',1)], [Fragment('h(o)tdog', pow(1.1,4))])
])]),
parse_prompt("a cat.swap(dog) eating a hotdog.swap(\"h\(o\)tdog++++\", shape_freedom=0.5)"))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1),
CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]),
Fragment('eating a', 1),
CrossAttentionControlSubstitute([Fragment('hotdog',1)], [Fragment('h(o)tdog', pow(0.9,1))])
])]),
parse_prompt("a cat.swap(dog) eating a hotdog.swap(h\(o\)tdog-, shape_freedom=0.5)"))
def test_cross_attention_control_options(self):
self.assertEqual(Conjunction([
FlattenedPrompt([Fragment('a', 1),
CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)], options={'s_start':0.1}),
Fragment('eating a hotdog', 1)])]),
parse_prompt("a \"cat\".swap(dog, s_start=0.1) eating a hotdog"))
self.assertEqual(Conjunction([
FlattenedPrompt([Fragment('a', 1),
CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)], options={'t_start':0.1}),
Fragment('eating a hotdog', 1)])]),
parse_prompt("a \"cat\".swap(dog, t_start=0.1) eating a hotdog"))
self.assertEqual(Conjunction([
FlattenedPrompt([Fragment('a', 1),
CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)], options={'s_start': 20.0, 't_start':0.1}),
Fragment('eating a hotdog', 1)])]),
parse_prompt("a \"cat\".swap(dog, t_start=0.1, s_start=20) eating a hotdog"))
self.assertEqual(
Conjunction([
FlattenedPrompt([Fragment('a fantasy forest landscape', 1),
CrossAttentionControlSubstitute([Fragment('', 1)], [Fragment('with a river', 1)],
options={'s_start': 0.8, 't_start': 0.8})])]),
parse_prompt("a fantasy forest landscape \"\".swap(with a river, s_start=0.8, t_start=0.8)"))
def test_escaping(self):
# make sure ", ( and ) can be escaped
self.assertEqual(make_basic_conjunction(['mountain (man)']),parse_prompt('mountain \(man\)'))
self.assertEqual(make_basic_conjunction(['mountain (man )']),parse_prompt('mountain (\(man)\)'))
self.assertEqual(make_basic_conjunction(['mountain (man)']),parse_prompt('mountain (\(man\))'))
self.assertEqual(make_weighted_conjunction([('mountain', 1), ('(man)', 1.1)]), parse_prompt('mountain (\(man\))+'))
self.assertEqual(make_weighted_conjunction([('mountain', 1), ('(man)', 1.1)]), parse_prompt('"mountain" (\(man\))+'))
self.assertEqual(make_weighted_conjunction([('"mountain"', 1), ('(man)', 1.1)]), parse_prompt('\\"mountain\\" (\(man\))+'))
# same weights for each are combined into one
self.assertEqual(make_weighted_conjunction([('"mountain" (man)', 1.1)]), parse_prompt('(\\"mountain\\")+ (\(man\))+'))
self.assertEqual(make_weighted_conjunction([('"mountain"', 1.1), ('(man)', 0.9)]), parse_prompt('(\\"mountain\\")+ (\(man\))-'))
self.assertEqual(make_weighted_conjunction([('mountain', 1), ('\(man\)', 1.1)]),parse_prompt('mountain (\(man\))1.1'))
self.assertEqual(make_weighted_conjunction([('mountain', 1), ('\(man\)', 1.1)]),parse_prompt('"mountain" (\(man\))1.1'))
self.assertEqual(make_weighted_conjunction([('"mountain"', 1), ('\(man\)', 1.1)]),parse_prompt('\\"mountain\\" (\(man\))1.1'))
# same weights for each are combined into one
self.assertEqual(make_weighted_conjunction([('\\"mountain\\" \(man\)', 1.1)]),parse_prompt('(\\"mountain\\")+ (\(man\))1.1'))
self.assertEqual(make_weighted_conjunction([('\\"mountain\\"', 1.1), ('\(man\)', 0.9)]),parse_prompt('(\\"mountain\\")1.1 (\(man\))0.9'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain', 1.1), ('\(man\)', 1.1*1.1)]),parse_prompt('hairy (mountain (\(man\))+)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('\(man\)', 1.1*1.1), ('mountain', 1.1)]),parse_prompt('hairy ((\(man\))1.1 "mountain")+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain', 1.1), ('\(man\)', 1.1*1.1)]),parse_prompt('hairy ("mountain" (\(man\))1.1 )+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , man', 1.1)]),parse_prompt('hairy ("mountain, man")+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , man with a', 1.1), ('beard', 1.1*1.1)]), parse_prompt('hairy ("mountain, man" with a beard+)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , man with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, man" with a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , \"man\" with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\"man\\"" with a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , m\"an\" with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, m\\"an\\"" with a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , \"man (with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" \(with a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , \"man w(ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" w\(ith a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , \"man with( a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" with\( a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , \"man )with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" \)with a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , \"man w)ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" w\)ith a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , \"man with) a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" with\) a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mou)ntain , \"man (wit(h a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mou\)ntain, \\\"man\" \(wit\(h a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hai(ry', 1), ('mountain , \"man w)ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hai\(ry ("mountain, \\\"man\" w\)ith a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hairy((', 1), ('mountain , \"man with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy\(\( ("mountain, \\\"man\" with a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('mountain , \"man (with a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" \(with a (beard)2.0)+ hairy'))
self.assertEqual(make_weighted_conjunction([('mountain , \"man w(ith a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" w\(ith a (beard)2.0)+hairy'))
self.assertEqual(make_weighted_conjunction([('mountain , \"man with( a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" with\( a (beard)2.0)+ hairy'))
self.assertEqual(make_weighted_conjunction([('mountain , \"man )with a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" \)with a (beard)2.0)+ hairy'))
self.assertEqual(make_weighted_conjunction([('mountain , \"man w)ith a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" w\)ith a (beard)2.0)+ hairy'))
self.assertEqual(make_weighted_conjunction([('mountain , \"man with) a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt(' ("mountain, \\\"man\" with\) a (beard)2.0)+ hairy'))
self.assertEqual(make_weighted_conjunction([('mou)ntain , \"man (wit(h a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mou\)ntain, \\\"man\" \(wit\(h a (beard)2.0)+ hairy'))
self.assertEqual(make_weighted_conjunction([('mountain , \"man w)ith a', 1.1), ('beard', 1.1*2.0), ('hai(ry', 1)]), parse_prompt('("mountain, \\\"man\" w\)ith a (beard)2.0)+ hai\(ry '))
self.assertEqual(make_weighted_conjunction([('mountain , \"man with a', 1.1), ('beard', 1.1*2.0), ('hairy((', 1)]), parse_prompt('("mountain, \\\"man\" with a (beard)2.0)+ hairy\(\( '))
def test_cross_attention_escaping(self):
self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('monkey', 1)])])]),
parse_prompt('mountain (man).swap(monkey)'))
self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('m(onkey', 1)])])]),
parse_prompt('mountain (man).swap(m\(onkey)'))
self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('m(an', 1)], [Fragment('m(onkey', 1)])])]),
parse_prompt('mountain (m\(an).swap(m\(onkey)'))
self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('(((', 1)], [Fragment('m(on))key', 1)])])]),
parse_prompt('mountain (\(\(\().swap(m\(on\)\)key)'))
self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('monkey', 1)])])]),
parse_prompt('mountain ("man").swap(monkey)'))
self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('monkey', 1)])])]),
parse_prompt('mountain ("man").swap("monkey")'))
self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('"man', 1)], [Fragment('monkey', 1)])])]),
parse_prompt('mountain (\\"man).swap("monkey")'))
self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('m(onkey', 1)])])]),
parse_prompt('mountain (man).swap(m\(onkey)'))
self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('m(an', 1)], [Fragment('m(onkey', 1)])])]),
parse_prompt('mountain (m\(an).swap(m\(onkey)'))
self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('(((', 1)], [Fragment('m(on))key', 1)])])]),
parse_prompt('mountain (\(\(\().swap(m\(on\)\)key)'))
def test_legacy_blend(self):
pp = PromptParser()
self.assertEqual(Blend([FlattenedPrompt([('mountain man', 1)]),
FlattenedPrompt([('man mountain', 1)])],
weights=[0.5,0.5]),
pp.parse_legacy_blend('mountain man:1 man mountain:1'))
self.assertEqual(Blend([FlattenedPrompt([('mountain', 1.1), ('man', 1)]),
FlattenedPrompt([('man', 1), ('mountain', 0.9)])],
weights=[0.5,0.5]),
pp.parse_legacy_blend('mountain+ man:1 man mountain-:1'))
self.assertEqual(Blend([FlattenedPrompt([('mountain', 1.1), ('man', 1)]),
FlattenedPrompt([('man', 1), ('mountain', 0.9)])],
weights=[0.5,0.5]),
pp.parse_legacy_blend('mountain+ man:1 man mountain-'))
self.assertEqual(Blend([FlattenedPrompt([('mountain', 1.1), ('man', 1)]),
FlattenedPrompt([('man', 1), ('mountain', 0.9)])],
weights=[0.5,0.5]),
pp.parse_legacy_blend('mountain+ man: man mountain-:'))
self.assertEqual(Blend([FlattenedPrompt([('mountain man', 1)]),
FlattenedPrompt([('man mountain', 1)])],
weights=[0.75,0.25]),
pp.parse_legacy_blend('mountain man:3 man mountain:1'))
self.assertEqual(Blend([FlattenedPrompt([('mountain man', 1)]),
FlattenedPrompt([('man mountain', 1)])],
weights=[1.0,0.0]),
pp.parse_legacy_blend('mountain man:3 man mountain:0'))
self.assertEqual(Blend([FlattenedPrompt([('mountain man', 1)]),
FlattenedPrompt([('man mountain', 1)])],
weights=[0.8,0.2]),
pp.parse_legacy_blend('"mountain man":4 man mountain'))
def test_single(self):
self.assertEqual(Conjunction([FlattenedPrompt([("mountain man", 1.0)]),
FlattenedPrompt([("a person with a hat", 1.0),
("riding a", 1.1*1.1),
CrossAttentionControlSubstitute(
[Fragment("bicycle", pow(1.1,2))],
[Fragment("skateboard", pow(1.1,2))])
])
], weights=[0.5, 0.5]),
parse_prompt("(\"mountain man\", \"a person with a hat (riding a bicycle.swap(skateboard))++\").and(0.5, 0.5)"))
pass
if __name__ == '__main__':
unittest.main()