Merge branch 'development' of github.com:invoke-ai/InvokeAI into development

This commit is contained in:
Lincoln Stein 2022-11-01 17:40:36 -04:00
commit 533fd04ef0
20 changed files with 14144 additions and 3483 deletions

View File

@ -605,9 +605,8 @@ class InvokeAIWebServer:
progress.set_current_status_has_steps(True) progress.set_current_status_has_steps(True)
if ( if (
generation_parameters["progress_images"] generation_parameters['progress_images'] and step % 5 == 0 \
and step % 5 == 0 and step < generation_parameters['steps'] - 1
and step < generation_parameters["steps"] - 1
): ):
image = self.generate.sample_to_image(sample) image = self.generate.sample_to_image(sample)
metadata = self.parameters_to_generated_image_metadata( metadata = self.parameters_to_generated_image_metadata(
@ -637,6 +636,25 @@ class InvokeAIWebServer:
"height": height, "height": height,
}, },
) )
# When the client asked for latent previews, send a cheap low-resolution
# estimate of the current step to the client as an inline (base64) image.
if generation_parameters['progress_latents']:
    image = self.generate.sample_to_lowres_estimated_image(sample)
    (width, height) = image.size
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    # The bytes in `buffered` are PNG-encoded (see format= above), so the
    # data URI has to declare image/png rather than image/jpeg.
    img_base64 = "data:image/png;base64," + base64.b64encode(buffered.getvalue()).decode('UTF-8')
    self.socketio.emit(
        "intermediateResult",
        {
            "url": img_base64,
            # tells the receiver that "url" is inline data, not a saved file
            "isBase64": True,
            "mtime": 0,
            "metadata": {},
            "width": width,
            "height": height,
        }
    )
self.socketio.emit("progressUpdate", progress.to_formatted_dict()) self.socketio.emit("progressUpdate", progress.to_formatted_dict())
eventlet.sleep(0) eventlet.sleep(0)

File diff suppressed because one or more lines are too long

517
frontend/dist/assets/index.cc049b93.js vendored Normal file

File diff suppressed because one or more lines are too long

View File

@ -6,7 +6,7 @@
<meta name="viewport" content="width=device-width, initial-scale=1.0" /> <meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>InvokeAI - A Stable Diffusion Toolkit</title> <title>InvokeAI - A Stable Diffusion Toolkit</title>
<link rel="shortcut icon" type="icon" href="./assets/favicon.0d253ced.ico" /> <link rel="shortcut icon" type="icon" href="./assets/favicon.0d253ced.ico" />
<script type="module" crossorigin src="./assets/index.044a626e.js"></script> <script type="module" crossorigin src="./assets/index.cc049b93.js"></script>
<link rel="stylesheet" href="./assets/index.52c8231e.css"> <link rel="stylesheet" href="./assets/index.52c8231e.css">
</head> </head>

10651
frontend/package-lock.json generated Normal file

File diff suppressed because it is too large Load Diff

View File

@ -37,3 +37,9 @@ export const NUMPY_RAND_MIN = 0;
export const NUMPY_RAND_MAX = 4294967295; export const NUMPY_RAND_MAX = 4294967295;
export const FACETOOL_TYPES = ['gfpgan', 'codeformer'] as const; export const FACETOOL_TYPES = ['gfpgan', 'codeformer'] as const;
/** One selectable entry: `key` is the label, `value` is what gets stored. */
type InProgressImageTypeOption = { key: string; value: string };

/**
 * Choices for the "Display In-Progress Images" setting.
 *
 * - `none`     - neither progress flag is sent to the backend
 * - `latents`  - sent to the backend as `progress_latents`
 * - `full-res` - sent to the backend as `progress_images`
 */
export const IN_PROGRESS_IMAGE_TYPES: InProgressImageTypeOption[] = [
  { key: 'None', value: 'none' },
  { key: 'Fast', value: 'latents' },
  { key: 'Accurate', value: 'full-res' },
];

View File

@ -115,7 +115,8 @@ export declare type Image = {
metadata?: Metadata; metadata?: Metadata;
width: number; width: number;
height: number; height: number;
category: GalleryCategory; category: GalleryCategory;
isBase64: boolean;
}; };
// GalleryImages is an array of Image. // GalleryImages is an array of Image.

View File

@ -261,18 +261,20 @@ const makeSocketIOListeners = (
const { intermediateImage } = getState().gallery; const { intermediateImage } = getState().gallery;
if (intermediateImage) { if (intermediateImage) {
dispatch( if (!intermediateImage.isBase64) {
addImage({ dispatch(
category: 'result', addImage({
image: intermediateImage, category: 'result',
}) image: intermediateImage,
); })
dispatch( );
addLogEntry({ dispatch(
timestamp: dateFormat(new Date(), 'isoDateTime'), addLogEntry({
message: `Intermediate image saved: ${intermediateImage.url}`, timestamp: dateFormat(new Date(), 'isoDateTime'),
}) message: `Intermediate image saved: ${intermediateImage.url}`,
); })
);
}
dispatch(clearIntermediateImage()); dispatch(clearIntermediateImage());
} }

View File

@ -62,7 +62,7 @@ export const frontendToBackendParameters = (
shouldRandomizeSeed, shouldRandomizeSeed,
} = optionsState; } = optionsState;
const { shouldDisplayInProgress } = systemState; const { shouldDisplayInProgressType } = systemState;
const generationParameters: { [k: string]: any } = { const generationParameters: { [k: string]: any } = {
prompt, prompt,
@ -75,7 +75,8 @@ export const frontendToBackendParameters = (
width, width,
sampler_name: sampler, sampler_name: sampler,
seed, seed,
progress_images: shouldDisplayInProgress, progress_images: shouldDisplayInProgressType === 'full-res',
progress_latents: shouldDisplayInProgressType === 'latents'
}; };
generationParameters.seed = shouldRandomizeSeed generationParameters.seed = shouldRandomizeSeed

View File

@ -44,13 +44,13 @@ const systemSelector = createSelector(
[ [
(state: RootState) => state.system, (state: RootState) => state.system,
(state: RootState) => state.options, (state: RootState) => state.options,
intermediateImageSelector, (state: RootState) => state.gallery,
activeTabNameSelector, activeTabNameSelector,
], ],
( (
system: SystemState, system: SystemState,
options: OptionsState, options: OptionsState,
intermediateImage, gallery: GalleryState,
activeTabName activeTabName
) => { ) => {
const { isProcessing, isConnected, isGFPGANAvailable, isESRGANAvailable } = const { isProcessing, isConnected, isGFPGANAvailable, isESRGANAvailable } =
@ -59,6 +59,8 @@ const systemSelector = createSelector(
const { upscalingLevel, facetoolStrength, shouldShowImageDetails } = const { upscalingLevel, facetoolStrength, shouldShowImageDetails } =
options; options;
const { intermediateImage } = gallery;
return { return {
isProcessing, isProcessing,
isConnected, isConnected,

View File

@ -20,23 +20,24 @@ import { persistor } from '../../../main';
import { import {
setShouldConfirmOnDelete, setShouldConfirmOnDelete,
setShouldDisplayGuides, setShouldDisplayGuides,
setShouldDisplayInProgress, setShouldDisplayInProgressType,
SystemState, SystemState,
} from '../systemSlice'; } from '../systemSlice';
import ModelList from './ModelList'; import ModelList from './ModelList';
import SettingsModalItem from './SettingsModalItem'; import { SettingsModalItem, SettingsModalSelectItem } from './SettingsModalItem';
import { IN_PROGRESS_IMAGE_TYPES } from '../../../app/constants';
const systemSelector = createSelector( const systemSelector = createSelector(
(state: RootState) => state.system, (state: RootState) => state.system,
(system: SystemState) => { (system: SystemState) => {
const { const {
shouldDisplayInProgress, shouldDisplayInProgressType,
shouldConfirmOnDelete, shouldConfirmOnDelete,
shouldDisplayGuides, shouldDisplayGuides,
model_list, model_list,
} = system; } = system;
return { return {
shouldDisplayInProgress, shouldDisplayInProgressType,
shouldConfirmOnDelete, shouldConfirmOnDelete,
shouldDisplayGuides, shouldDisplayGuides,
models: _.map(model_list, (_model, key) => key), models: _.map(model_list, (_model, key) => key),
@ -72,7 +73,7 @@ const SettingsModal = ({ children }: SettingsModalProps) => {
} = useDisclosure(); } = useDisclosure();
const { const {
shouldDisplayInProgress, shouldDisplayInProgressType,
shouldConfirmOnDelete, shouldConfirmOnDelete,
shouldDisplayGuides, shouldDisplayGuides,
} = useAppSelector(systemSelector); } = useAppSelector(systemSelector);
@ -102,10 +103,12 @@ const SettingsModal = ({ children }: SettingsModalProps) => {
<ModalBody className="settings-modal-content"> <ModalBody className="settings-modal-content">
<ModelList /> <ModelList />
<div className="settings-modal-items"> <div className="settings-modal-items">
<SettingsModalItem
settingTitle="Display In-Progress Images (slower)" <SettingsModalSelectItem
isChecked={shouldDisplayInProgress} settingTitle="Display In-Progress Images"
dispatcher={setShouldDisplayInProgress} validValues={IN_PROGRESS_IMAGE_TYPES}
defaultValue={shouldDisplayInProgressType}
dispatcher={setShouldDisplayInProgressType}
/> />
<SettingsModalItem <SettingsModalItem

View File

@ -1,7 +1,8 @@
import { useAppDispatch } from '../../../app/store'; import { useAppDispatch } from '../../../app/store';
import IAISelect from '../../../common/components/IAISelect';
import IAISwitch from '../../../common/components/IAISwitch'; import IAISwitch from '../../../common/components/IAISwitch';
export default function SettingsModalItem({ export function SettingsModalItem({
settingTitle, settingTitle,
isChecked, isChecked,
dispatcher, dispatcher,
@ -20,3 +21,30 @@ export default function SettingsModalItem({
/> />
); );
} }
/** Props accepted by SettingsModalSelectItem. */
type SettingsModalSelectItemProps = {
  /** Text handed to the select as its label. */
  settingTitle: string;
  /** Choices handed to the select: bare values or key/value pairs. */
  validValues:
    | Array<number | string>
    | Array<{ key: string; value: string | number }>;
  /** Handed to the select as its `defaultValue`. */
  defaultValue: string;
  /** Called with the newly selected value; whatever it returns is dispatched. */
  dispatcher: any;
};

/**
 * A settings-modal row rendered as an IAISelect.
 *
 * On every change event the selected value (`e.target.value`) is passed to
 * `dispatcher`, and the result is dispatched to the app store.
 */
export function SettingsModalSelectItem(props: SettingsModalSelectItemProps) {
  const { settingTitle, validValues, defaultValue, dispatcher } = props;
  const dispatch = useAppDispatch();

  return (
    <IAISelect
      styleClass="settings-modal-item"
      label={settingTitle}
      validValues={validValues}
      defaultValue={defaultValue}
      onChange={(e) => dispatch(dispatcher(e.target.value))}
    />
  );
}

View File

@ -18,7 +18,7 @@ export interface Log {
export interface SystemState export interface SystemState
extends InvokeAI.SystemStatus, extends InvokeAI.SystemStatus,
InvokeAI.SystemConfig { InvokeAI.SystemConfig {
shouldDisplayInProgress: boolean; shouldDisplayInProgressType: string;
log: Array<LogEntry>; log: Array<LogEntry>;
shouldShowLogViewer: boolean; shouldShowLogViewer: boolean;
isGFPGANAvailable: boolean; isGFPGANAvailable: boolean;
@ -43,7 +43,7 @@ const initialSystemState = {
isProcessing: false, isProcessing: false,
log: [], log: [],
shouldShowLogViewer: false, shouldShowLogViewer: false,
shouldDisplayInProgress: false, shouldDisplayInProgressType: "none",
shouldDisplayGuides: true, shouldDisplayGuides: true,
isGFPGANAvailable: true, isGFPGANAvailable: true,
isESRGANAvailable: true, isESRGANAvailable: true,
@ -73,8 +73,8 @@ export const systemSlice = createSlice({
name: 'system', name: 'system',
initialState, initialState,
reducers: { reducers: {
setShouldDisplayInProgress: (state, action: PayloadAction<boolean>) => { setShouldDisplayInProgressType: (state, action: PayloadAction<string>) => {
state.shouldDisplayInProgress = action.payload; state.shouldDisplayInProgressType = action.payload;
}, },
setIsProcessing: (state, action: PayloadAction<boolean>) => { setIsProcessing: (state, action: PayloadAction<boolean>) => {
state.isProcessing = action.payload; state.isProcessing = action.payload;
@ -182,7 +182,7 @@ export const systemSlice = createSlice({
}); });
export const { export const {
setShouldDisplayInProgress, setShouldDisplayInProgressType,
setIsProcessing, setIsProcessing,
addLogEntry, addLogEntry,
setShouldShowLogViewer, setShouldShowLogViewer,

File diff suppressed because it is too large Load Diff

View File

@ -913,6 +913,9 @@ class Generate:
def sample_to_image(self, samples): def sample_to_image(self, samples):
return self._make_base().sample_to_image(samples) return self._make_base().sample_to_image(samples)
def sample_to_lowres_estimated_image(self, samples):
    """
    Return a low-resolution estimated image for `samples`.

    Pure delegation: the work is done by the object returned from
    self._make_base(), whose method of the same name is called with
    `samples` unchanged.
    """
    base_generator = self._make_base()
    return base_generator.sample_to_lowres_estimated_image(samples)
# very repetitive code - can this be simplified? The KSampler names are # very repetitive code - can this be simplified? The KSampler names are
# consistent, at least # consistent, at least
def _set_sampler(self): def _set_sampler(self):

View File

@ -116,6 +116,29 @@ class Generator():
) )
return Image.fromarray(x_sample.astype(np.uint8)) return Image.fromarray(x_sample.astype(np.uint8))
def sample_to_lowres_estimated_image(self, samples):
    """
    Build an approximate RGB image from the latent samples of a single step.

    Only samples[0] is used. It is moved to channels-last order and multiplied
    by a fixed 4x3 matrix, so it must have 4 channels in its first dimension;
    the result has one RGB pixel per latent position.

    :param samples: latent tensor; indexed as samples[0] -> (4, height, width)
    :return: the image produced by Image.fromarray from the uint8 RGB array
    """
    # adapted from code by @erucipe and @keturn here:
    # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7

    # these numbers were determined empirically by @keturn
    latent_to_rgb = torch.tensor(
        [
            #   R       G       B
            [ 0.298,  0.207,  0.208],  # L1
            [ 0.187,  0.286,  0.173],  # L2
            [-0.158,  0.189,  0.264],  # L3
            [-0.184, -0.271, -0.473],  # L4
        ],
        dtype=samples.dtype,
        device=samples.device,
    )

    # (4, h, w) -> (h, w, 4), then project the 4 latent channels onto RGB
    channels_last = samples[0].permute(1, 2, 0)
    rgb = channels_last @ latent_to_rgb

    # change scale from -1..1 to 0..1, clipping anything outside that range
    unit_range = ((rgb + 1) / 2).clamp(0, 1)
    # to 0..255 as unsigned bytes, moved to the CPU so numpy can read it
    rgb_bytes = unit_range.mul(0xFF).byte().cpu()

    return Image.fromarray(rgb_bytes.numpy())
def generate_initial_noise(self, seed, width, height): def generate_initial_noise(self, seed, width, height):
initial_noise = None initial_noise = None
if self.variation_amount > 0 or len(self.with_variations) > 0: if self.variation_amount > 0 or len(self.with_variations) > 0:

View File

@ -28,7 +28,7 @@ class Prompt():
def __init__(self, parts: list): def __init__(self, parts: list):
for c in parts: for c in parts:
if type(c) is not Attention and not issubclass(type(c), BaseFragment) and type(c) is not pp.ParseResults: if type(c) is not Attention and not issubclass(type(c), BaseFragment) and type(c) is not pp.ParseResults:
raise PromptParser.ParsingException(f"Prompt cannot contain {type(c).__name__} {c}, only {BaseFragment.__subclasses__()} are allowed") raise PromptParser.ParsingException(f"Prompt cannot contain {type(c).__name__} ({c}), only {[c.__name__ for c in BaseFragment.__subclasses__()]} are allowed")
self.children = parts self.children = parts
def __repr__(self): def __repr__(self):
return f"Prompt:{self.children}" return f"Prompt:{self.children}"
@ -102,12 +102,18 @@ class Attention():
Do not traverse directly; instead obtain a FlattenedPrompt by calling Flatten() on a top-level Conjunction object. Do not traverse directly; instead obtain a FlattenedPrompt by calling Flatten() on a top-level Conjunction object.
""" """
def __init__(self, weight: float, children: list): def __init__(self, weight: float, children: list):
if type(weight) is not float:
raise PromptParser.ParsingException(
f"Attention weight must be float (got {type(weight).__name__} {weight})")
self.weight = weight self.weight = weight
if type(children) is not list:
raise PromptParser.ParsingException(f"cannot make Attention with non-list of children (got {type(children)})")
assert(type(children) is list)
self.children = children self.children = children
#print(f"A: requested attention '{children}' to {weight}") #print(f"A: requested attention '{children}' to {weight}")
def __repr__(self): def __repr__(self):
return f"Attention:'{self.children}' @ {self.weight}" return f"Attention:{self.children} * {self.weight}"
def __eq__(self, other): def __eq__(self, other):
return type(other) is Attention and other.weight == self.weight and other.fragment == self.fragment return type(other) is Attention and other.weight == self.weight and other.fragment == self.fragment
@ -136,9 +142,9 @@ class CrossAttentionControlSubstitute(CrossAttentionControlledFragment):
Fragment('sitting on a car') Fragment('sitting on a car')
]) ])
""" """
def __init__(self, original: Union[Fragment, list], edited: Union[Fragment, list], options: dict=None): def __init__(self, original: list, edited: list, options: dict=None):
self.original = original self.original = original
self.edited = edited self.edited = edited if len(edited)>0 else [Fragment('')]
default_options = { default_options = {
's_start': 0.0, 's_start': 0.0,
@ -190,12 +196,12 @@ class Conjunction():
""" """
def __init__(self, prompts: list, weights: list = None): def __init__(self, prompts: list, weights: list = None):
# force everything to be a Prompt # force everything to be a Prompt
#print("making conjunction with", parts) #print("making conjunction with", prompts, "types", [type(p).__name__ for p in prompts])
self.prompts = [x if (type(x) is Prompt self.prompts = [x if (type(x) is Prompt
or type(x) is Blend or type(x) is Blend
or type(x) is FlattenedPrompt) or type(x) is FlattenedPrompt)
else Prompt(x) for x in prompts] else Prompt(x) for x in prompts]
self.weights = [1.0]*len(self.prompts) if weights is None else list(weights) self.weights = [1.0]*len(self.prompts) if (weights is None or len(weights)==0) else list(weights)
if len(self.weights) != len(self.prompts): if len(self.weights) != len(self.prompts):
raise PromptParser.ParsingException(f"while parsing Conjunction: mismatched parts/weights counts {prompts}, {weights}") raise PromptParser.ParsingException(f"while parsing Conjunction: mismatched parts/weights counts {prompts}, {weights}")
self.type = 'AND' self.type = 'AND'
@ -216,6 +222,7 @@ class Blend():
""" """
def __init__(self, prompts: list, weights: list[float], normalize_weights: bool=True): def __init__(self, prompts: list, weights: list[float], normalize_weights: bool=True):
#print("making Blend with prompts", prompts, "and weights", weights) #print("making Blend with prompts", prompts, "and weights", weights)
weights = [1.0]*len(prompts) if (weights is None or len(weights)==0) else list(weights)
if len(prompts) != len(weights): if len(prompts) != len(weights):
raise PromptParser.ParsingException(f"while parsing Blend: mismatched prompts/weights counts {prompts}, {weights}") raise PromptParser.ParsingException(f"while parsing Blend: mismatched prompts/weights counts {prompts}, {weights}")
for p in prompts: for p in prompts:
@ -244,6 +251,10 @@ class PromptParser():
class ParsingException(Exception): class ParsingException(Exception):
pass pass
class UnrecognizedOperatorException(ParsingException):
    """Parsing error carrying the name of an operator that was not recognized."""

    def __init__(self, operator:str):
        # `operator` must be a str: it is concatenated onto the message prefix
        message = "Unrecognized operator: " + operator
        super().__init__(message)
def __init__(self, attention_plus_base=1.1, attention_minus_base=0.9): def __init__(self, attention_plus_base=1.1, attention_minus_base=0.9):
self.conjunction, self.prompt = build_parser_syntax(attention_plus_base, attention_minus_base) self.conjunction, self.prompt = build_parser_syntax(attention_plus_base, attention_minus_base)
@ -279,7 +290,7 @@ class PromptParser():
return Blend(prompts=flattened_prompts, weights=weights, normalize_weights=True) return Blend(prompts=flattened_prompts, weights=weights, normalize_weights=True)
def flatten(self, root: Conjunction) -> Conjunction: def flatten(self, root: Conjunction, verbose = False) -> Conjunction:
""" """
Flattening a Conjunction traverses all of the nested tree-like structures in each of its Prompts or Blends, Flattening a Conjunction traverses all of the nested tree-like structures in each of its Prompts or Blends,
producing from each of these walks a linear sequence of Fragment or CrossAttentionControlSubstitute objects producing from each of these walks a linear sequence of Fragment or CrossAttentionControlSubstitute objects
@ -289,8 +300,6 @@ class PromptParser():
:return: A Conjunction containing the result of flattening each of the prompts in the passed-in root. :return: A Conjunction containing the result of flattening each of the prompts in the passed-in root.
""" """
#print("flattening", root)
def fuse_fragments(items): def fuse_fragments(items):
# print("fusing fragments in ", items) # print("fusing fragments in ", items)
result = [] result = []
@ -313,8 +322,8 @@ class PromptParser():
return result return result
def flatten_internal(node, weight_scale, results, prefix): def flatten_internal(node, weight_scale, results, prefix):
#print(prefix + "flattening", node, "...") verbose and print(prefix + "flattening", node, "...")
if type(node) is pp.ParseResults: if type(node) is pp.ParseResults or type(node) is list:
for x in node: for x in node:
results = flatten_internal(x, weight_scale, results, prefix+' pr ') results = flatten_internal(x, weight_scale, results, prefix+' pr ')
#print(prefix, " ParseResults expanded, results is now", results) #print(prefix, " ParseResults expanded, results is now", results)
@ -345,67 +354,59 @@ class PromptParser():
#print(prefix + "after flattening Prompt, results is", results) #print(prefix + "after flattening Prompt, results is", results)
else: else:
raise PromptParser.ParsingException(f"unhandled node type {type(node)} when flattening {node}") raise PromptParser.ParsingException(f"unhandled node type {type(node)} when flattening {node}")
#print(prefix + "-> after flattening", type(node).__name__, "results is", results) verbose and print(prefix + "-> after flattening", type(node).__name__, "results is", results)
return results return results
verbose and print("flattening", root)
flattened_parts = [] flattened_parts = []
for part in root.prompts: for part in root.prompts:
flattened_parts += flatten_internal(part, 1.0, [], ' C| ') flattened_parts += flatten_internal(part, 1.0, [], ' C| ')
#print("flattened to", flattened_parts) verbose and print("flattened to", flattened_parts)
weights = root.weights weights = root.weights
return Conjunction(flattened_parts, weights) return Conjunction(flattened_parts, weights)
def build_parser_syntax(attention_plus_base: float, attention_minus_base: float): def build_parser_syntax(attention_plus_base: float, attention_minus_base: float):
def make_operator_object(x):
#print('making operator for', x)
target = x[0]
operator = x[1]
arguments = x[2]
if operator == '.attend':
weight_raw = arguments[0]
weight = 1.0
if type(weight_raw) is float or type(weight_raw) is int:
weight = weight_raw
elif type(weight_raw) is str:
base = attention_plus_base if weight_raw[0] == '+' else attention_minus_base
weight = pow(base, len(weight_raw))
return Attention(weight=weight, children=[x for x in x[0]])
elif operator == '.swap':
return CrossAttentionControlSubstitute(target, arguments, x.as_dict())
elif operator == '.blend':
prompts = [Prompt(p) for p in x[0]]
weights_raw = x[2]
normalize_weights = True
if len(weights_raw) > 0 and weights_raw[-1][0] == 'no_normalize':
normalize_weights = False
weights_raw = weights_raw[:-1]
weights = [float(w[0]) for w in weights_raw]
return Blend(prompts=prompts, weights=weights, normalize_weights=normalize_weights)
elif operator == '.and' or operator == '.add':
prompts = [Prompt(p) for p in x[0]]
weights = [float(w[0]) for w in x[2]]
return Conjunction(prompts=prompts, weights=weights)
lparen = pp.Literal("(").suppress() raise PromptParser.UnrecognizedOperatorException(operator)
rparen = pp.Literal(")").suppress()
quotes = pp.Literal('"').suppress()
comma = pp.Literal(",").suppress()
# accepts int or float notation, always maps to float def parse_fragment_str(x, expression: pp.ParseExpression, in_quotes: bool = False, in_parens: bool = False):
number = pp.pyparsing_common.real | \
pp.Combine(pp.Optional("-")+pp.Word(pp.nums)).set_parse_action(pp.token_map(float))
attention = pp.Forward()
quoted_fragment = pp.Forward()
parenthesized_fragment = pp.Forward()
cross_attention_substitute = pp.Forward()
def make_text_fragment(x):
#print("### making fragment for", x)
if type(x[0]) is Fragment:
assert(False)
if type(x) is str:
return Fragment(x)
elif type(x) is pp.ParseResults or type(x) is list:
#print(f'converting {type(x).__name__} to Fragment')
return Fragment(' '.join([s for s in x]))
else:
raise PromptParser.ParsingException("Cannot make fragment from " + str(x))
def build_escaped_word_parser_charbychar(escaped_chars_to_ignore: str):
escapes = []
for c in escaped_chars_to_ignore:
escapes.append(pp.Literal('\\'+c))
return pp.Combine(pp.OneOrMore(
pp.MatchFirst(escapes + [pp.CharsNotIn(
string.whitespace + escaped_chars_to_ignore,
exact=1
)])
))
def parse_fragment_str(x, in_quotes: bool=False, in_parens: bool=False):
#print(f"parsing fragment string for {x}") #print(f"parsing fragment string for {x}")
fragment_string = x[0] fragment_string = x[0]
#print(f"ppparsing fragment string \"{fragment_string}\"")
if len(fragment_string.strip()) == 0: if len(fragment_string.strip()) == 0:
return Fragment('') return Fragment('')
@ -413,234 +414,198 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float)
# escape unescaped quotes # escape unescaped quotes
fragment_string = fragment_string.replace('"', '\\"') fragment_string = fragment_string.replace('"', '\\"')
#fragment_parser = pp.Group(pp.OneOrMore(attention | cross_attention_substitute | (greedy_word.set_parse_action(make_text_fragment))))
try: try:
result = pp.Group(pp.MatchFirst([ result = (expression + pp.StringEnd()).parse_string(fragment_string)
pp.OneOrMore(quoted_fragment | attention | unquoted_word).set_name('pf_str_qfuq'),
pp.Empty().set_parse_action(make_text_fragment) + pp.StringEnd()
])).set_name('blend-result').set_debug(False).parse_string(fragment_string)
#print("parsed to", result) #print("parsed to", result)
return result return result
except pp.ParseException as e: except pp.ParseException as e:
#print("parse_fragment_str couldn't parse prompt string:", e) #print("parse_fragment_str couldn't parse prompt string:", e)
raise raise
# meaningful symbols
lparen = pp.Literal("(").suppress()
rparen = pp.Literal(")").suppress()
quote = pp.Literal('"').suppress()
comma = pp.Literal(",").suppress()
dot = pp.Literal(".").suppress()
equals = pp.Literal("=").suppress()
escaped_lparen = pp.Literal('\\(')
escaped_rparen = pp.Literal('\\)')
escaped_quote = pp.Literal('\\"')
escaped_comma = pp.Literal('\\,')
escaped_dot = pp.Literal('\\.')
escaped_plus = pp.Literal('\\+')
escaped_minus = pp.Literal('\\-')
escaped_equals = pp.Literal('\\=')
syntactic_symbols = {
'(': escaped_lparen,
')': escaped_rparen,
'"': escaped_quote,
',': escaped_comma,
'.': escaped_dot,
'+': escaped_plus,
'-': escaped_minus,
'=': escaped_equals,
}
syntactic_chars = "".join(syntactic_symbols.keys())
# accepts int or float notation, always maps to float
number = pp.pyparsing_common.real | \
pp.Combine(pp.Optional("-")+pp.Word(pp.nums)).set_parse_action(pp.token_map(float))
# for options
keyword = pp.Word(pp.alphanums + '_')
# a word that absolutely does not contain any meaningful syntax
non_syntax_word = pp.Combine(pp.OneOrMore(pp.MatchFirst([
pp.Or(syntactic_symbols.values()),
pp.one_of(['-', '+']) + pp.NotAny(pp.White() | pp.Char(syntactic_chars) | pp.StringEnd()),
# build character-by-character
pp.CharsNotIn(string.whitespace + syntactic_chars, exact=1)
])))
non_syntax_word.set_parse_action(lambda x: [Fragment(t) for t in x])
non_syntax_word.set_name('non_syntax_word')
non_syntax_word.set_debug(False)
# a word that can contain any character at all - greedily consumes syntax, so use with care
free_word = pp.CharsNotIn(string.whitespace).set_parse_action(lambda x: Fragment(x[0]))
free_word.set_name('free_word')
free_word.set_debug(False)
# ok here we go. forward declare some things..
attention = pp.Forward()
cross_attention_substitute = pp.Forward()
parenthesized_fragment = pp.Forward()
quoted_fragment = pp.Forward()
# the types of things that can go into a fragment, consisting of syntax-full and/or strictly syntax-free components
fragment_part_expressions = [
attention,
cross_attention_substitute,
parenthesized_fragment,
quoted_fragment,
non_syntax_word
]
# a fragment that is permitted to contain commas
fragment_including_commas = pp.ZeroOrMore(pp.MatchFirst(
fragment_part_expressions + [
pp.Literal(',').set_parse_action(lambda x: Fragment(x[0]))
]
))
# a fragment that is not permitted to contain commas
fragment_excluding_commas = pp.ZeroOrMore(pp.MatchFirst(
fragment_part_expressions
))
# a fragment in double quotes (may be nested)
quoted_fragment << pp.QuotedString(quote_char='"', esc_char=None, esc_quote='\\"') quoted_fragment << pp.QuotedString(quote_char='"', esc_char=None, esc_quote='\\"')
quoted_fragment.set_parse_action(lambda x: parse_fragment_str(x, in_quotes=True)).set_name('quoted_fragment') quoted_fragment.set_parse_action(lambda x: parse_fragment_str(x, fragment_including_commas, in_quotes=True))
escaped_quote = pp.Literal('\\"')#.set_parse_action(lambda x: '"') # a fragment inside parentheses (may be nested)
escaped_lparen = pp.Literal('\\(')#.set_parse_action(lambda x: '(') parenthesized_fragment << (lparen + fragment_including_commas + rparen)
escaped_rparen = pp.Literal('\\)')#.set_parse_action(lambda x: ')') parenthesized_fragment.set_name('parenthesized_fragment')
escaped_backslash = pp.Literal('\\\\')#.set_parse_action(lambda x: '"') parenthesized_fragment.set_debug(False)
empty = ( # a string of the form (<keyword>=<float|keyword> | <float> | <keyword>) where keyword is alphanumeric + '_'
(lparen + pp.ZeroOrMore(pp.Word(string.whitespace)) + rparen) | option = pp.Group(pp.MatchFirst([
(quotes + pp.ZeroOrMore(pp.Word(string.whitespace)) + quotes)).set_debug(False).set_name('empty') keyword + equals + (number | keyword), # option=value
number.copy().set_parse_action(pp.token_map(str)), # weight
keyword # flag
def not_ends_with_swap(x):
#print("trying to match:", x)
return not x[0].endswith('.swap')
unquoted_word = (pp.Combine(pp.OneOrMore(
escaped_rparen | escaped_lparen | escaped_quote | escaped_backslash |
(pp.CharsNotIn(string.whitespace + '\\"()', exact=1)
)))
# don't whitespace when the next word starts with +, eg "badly +formed"
+ (pp.White().suppress() |
# don't eat +/-
pp.NotAny(pp.Word('+') | pp.Word('-'))
)
)
unquoted_word.set_parse_action(make_text_fragment).set_name('unquoted_word').set_debug(False)
#print(unquoted_fragment.parse_string("cat.swap(dog)"))
parenthesized_fragment << (lparen +
pp.Or([
(parenthesized_fragment),
(quoted_fragment.copy().set_parse_action(lambda x: parse_fragment_str(x, in_quotes=True)).set_debug(False)).set_name('-quoted_paren_internal').set_debug(False),
(pp.Combine(pp.OneOrMore(
escaped_quote | escaped_lparen | escaped_rparen | escaped_backslash |
pp.CharsNotIn(string.whitespace + '\\"()', exact=1) |
pp.White()
)).set_name('--combined').set_parse_action(lambda x: parse_fragment_str(x, in_parens=True)).set_debug(False)),
pp.Empty()
]) + rparen)
parenthesized_fragment.set_name('parenthesized_fragment').set_debug(False)
debug_attention = False
# attention control of the form (phrase)+ / (phrase)+ / (phrase)<weight>
# phrase can be multiple words, can have multiple +/- signs to increase the effect or type a floating point or integer weight
attention_with_parens = pp.Forward()
attention_without_parens = pp.Forward()
attention_with_parens_foot = (number | pp.Word('+') | pp.Word('-'))\
.set_name("attention_foot")\
.set_debug(False)
attention_with_parens <<= pp.Group(
lparen +
pp.ZeroOrMore(quoted_fragment | attention_with_parens | parenthesized_fragment | cross_attention_substitute | attention_without_parens |
(pp.Empty() + build_escaped_word_parser_charbychar('()')).set_name('undecorated_word').set_debug(debug_attention)#.set_parse_action(lambda t: t[0])
)
+ rparen + attention_with_parens_foot)
attention_with_parens.set_name('attention_with_parens').set_debug(debug_attention)
attention_without_parens_foot = (pp.NotAny(pp.White()) + pp.Or([pp.Word('+'), pp.Word('-')]) + pp.FollowedBy(pp.StringEnd() | pp.White() | pp.Literal('(') | pp.Literal(')') | pp.Literal(',') | pp.Literal('"')) ).set_name('attention_without_parens_foots')
attention_without_parens <<= pp.Group(pp.MatchFirst([
quoted_fragment.copy().set_name('attention_quoted_fragment_without_parens').set_debug(debug_attention) + attention_without_parens_foot,
pp.Combine(build_escaped_word_parser_charbychar('()+-')).set_name('attention_word_without_parens').set_debug(debug_attention)#.set_parse_action(lambda x: print('escapéd', x))
+ attention_without_parens_foot#.leave_whitespace()
])) ]))
attention_without_parens.set_name('attention_without_parens').set_debug(debug_attention) # options for an operator, eg "s_start=0.1, 0.3, no_normalize"
options = pp.Dict(pp.Optional(pp.delimited_list(option)))
options.set_name('options')
options.set_debug(False)
# a fragment which can be used as the target for an operator - either quoted or in parentheses, or a bare vanilla word
potential_operator_target = (quoted_fragment | parenthesized_fragment | non_syntax_word)
attention << pp.MatchFirst([attention_with_parens, # a fragment whose weight has been increased or decreased by a given amount
attention_without_parens attention_weight_operator = pp.Word('+') | pp.Word('-') | number
]) attention_explicit = (
pp.Group(potential_operator_target)
+ pp.Literal('.attend')
+ lparen
+ pp.Group(attention_weight_operator)
+ rparen
)
attention_explicit.set_parse_action(make_operator_object)
attention_implicit = (
pp.Group(potential_operator_target)
+ pp.NotAny(pp.White()) # do not permit whitespace between term and operator
+ pp.Group(attention_weight_operator)
)
attention_implicit.set_parse_action(lambda x: make_operator_object([x[0], '.attend', x[1]]))
attention << (attention_explicit | attention_implicit)
attention.set_name('attention') attention.set_name('attention')
attention.set_debug(False)
def make_attention(x): # cross-attention control by swapping one fragment for another
#print("entered make_attention with", x) cross_attention_substitute << (
children = x[0][:-1] pp.Group(potential_operator_target).set_name('ca-target').set_debug(False)
weight_raw = x[0][-1] + pp.Literal(".swap").set_name('ca-operator').set_debug(False)
weight = 1.0 + lparen
if type(weight_raw) is float or type(weight_raw) is int: + pp.Group(fragment_excluding_commas).set_name('ca-replacement').set_debug(False)
weight = weight_raw + pp.Optional(comma + options).set_name('ca-options').set_debug(False)
elif type(weight_raw) is str: + rparen
base = attention_plus_base if weight_raw[0] == '+' else attention_minus_base )
weight = pow(base, len(weight_raw)) cross_attention_substitute.set_name('cross_attention_substitute')
cross_attention_substitute.set_debug(False)
#print("making Attention from", children, "with weight", weight) cross_attention_substitute.set_parse_action(make_operator_object)
return Attention(weight=weight, children=[(Fragment(x) if type(x) is str else x) for x in children])
attention_with_parens.set_parse_action(make_attention)
attention_without_parens.set_parse_action(make_attention)
#print("parsing test:", attention_with_parens.parse_string("mountain (man)1.1"))
# cross-attention control
empty_string = ((lparen + rparen) |
pp.Literal('""').suppress() |
(lparen + pp.Literal('""').suppress() + rparen)
).set_parse_action(lambda x: Fragment(""))
empty_string.set_name('empty_string')
# cross attention control
debug_cross_attention_control = False
original_fragment = pp.MatchFirst([
quoted_fragment.set_debug(debug_cross_attention_control),
parenthesized_fragment.set_debug(debug_cross_attention_control),
pp.Combine(pp.OneOrMore(pp.CharsNotIn(string.whitespace + '.', exact=1))).set_parse_action(make_text_fragment) + pp.FollowedBy(".swap"),
empty_string.set_debug(debug_cross_attention_control),
])
# support keyword=number arguments
cross_attention_option_keyword = pp.Or([pp.Keyword("s_start"), pp.Keyword("s_end"), pp.Keyword("t_start"), pp.Keyword("t_end"), pp.Keyword("shape_freedom")])
cross_attention_option = pp.Group(cross_attention_option_keyword + pp.Literal("=").suppress() + number)
edited_fragment = pp.MatchFirst([
(lparen + rparen).set_parse_action(lambda x: Fragment('')),
lparen +
(quoted_fragment | attention |
pp.Group(pp.ZeroOrMore(build_escaped_word_parser_charbychar(',)').set_parse_action(make_text_fragment)))
) +
pp.Dict(pp.ZeroOrMore(comma + cross_attention_option)) +
rparen,
parenthesized_fragment
])
cross_attention_substitute << original_fragment + pp.Literal(".swap").set_debug(False).suppress() + edited_fragment
original_fragment.set_name('original_fragment').set_debug(debug_cross_attention_control)
edited_fragment.set_name('edited_fragment').set_debug(debug_cross_attention_control)
cross_attention_substitute.set_name('cross_attention_substitute').set_debug(debug_cross_attention_control)
def make_cross_attention_substitute(x):
#print("making cacs for", x[0], "->", x[1], "with options", x.as_dict())
#if len(x>2):
cacs = CrossAttentionControlSubstitute(x[0], x[1], options=x.as_dict())
#print("made", cacs)
return cacs
cross_attention_substitute.set_parse_action(make_cross_attention_substitute)
# root prompt definition # an entire self-contained prompt, which can be used in a Blend or Conjunction
debug_root_prompt = False prompt = pp.ZeroOrMore(pp.MatchFirst([
prompt = (pp.OneOrMore(pp.MatchFirst([cross_attention_substitute.set_debug(debug_root_prompt), cross_attention_substitute,
attention.set_debug(debug_root_prompt), attention,
quoted_fragment.set_debug(debug_root_prompt), quoted_fragment,
parenthesized_fragment.set_debug(debug_root_prompt), parenthesized_fragment,
unquoted_word.set_debug(debug_root_prompt), free_word,
empty.set_parse_action(make_text_fragment).set_debug(debug_root_prompt)]) pp.White().suppress()
) + pp.StringEnd()) \ ]))
.set_name('prompt') \ quoted_prompt = quoted_fragment.copy().set_parse_action(lambda x: parse_fragment_str(x, prompt, in_quotes=True))
.set_parse_action(lambda x: Prompt(x)) \
.set_debug(debug_root_prompt)
#print("parsing test:", prompt.parse_string("spaced eyes--"))
#print("parsing test:", prompt.parse_string("eyes--"))
# weighted blend of prompts # a blend/lerp between the feature vectors for two or more prompts
# ("promptA", "promptB").blend(a, b) where "promptA" and "promptB" are valid prompts and a and b are float or blend = (
# int weights. lparen
# can specify more terms eg ("promptA", "promptB", "promptC").blend(a,b,c) + pp.Group(pp.delimited_list(pp.Group(potential_operator_target | quoted_prompt), min=1)).set_name('bl-target').set_debug(False)
+ rparen
+ pp.Literal(".blend").set_name('bl-operator').set_debug(False)
+ lparen
+ pp.Group(options).set_name('bl-options').set_debug(False)
+ rparen
)
blend.set_name('blend')
blend.set_debug(False)
blend.set_parse_action(make_operator_object)
def make_prompt_from_quoted_string(x): # an operator to direct stable diffusion to step multiple times, once for each target, and then add the results together with different weights
#print(' got quoted prompt', x) explicit_conjunction = (
lparen
+ pp.Group(pp.delimited_list(pp.Group(potential_operator_target | quoted_prompt), min=1)).set_name('cj-target').set_debug(False)
+ rparen
+ pp.one_of([".and", ".add"]).set_name('cj-operator').set_debug(False)
+ lparen
+ pp.Group(options).set_name('cj-options').set_debug(False)
+ rparen
)
explicit_conjunction.set_name('explicit_conjunction')
explicit_conjunction.set_debug(False)
explicit_conjunction.set_parse_action(make_operator_object)
x_unquoted = x[0][1:-1] # by default a prompt consists of a Conjunction with a single term
if len(x_unquoted.strip()) == 0: implicit_conjunction = (blend | pp.Group(prompt)) + pp.StringEnd()
# print(' b : just an empty string')
return Prompt([Fragment('')])
#print(f' b parsing \'{x_unquoted}\'')
x_parsed = prompt.parse_string(x_unquoted)
#print(" quoted prompt was parsed to", type(x_parsed),":", x_parsed)
return x_parsed[0]
quoted_prompt = pp.dbl_quoted_string.set_parse_action(make_prompt_from_quoted_string)
quoted_prompt.set_name('quoted_prompt')
debug_blend=False
blend_terms = pp.delimited_list(quoted_prompt).set_name('blend_terms').set_debug(debug_blend)
blend_weights = (pp.delimited_list(number) + pp.Optional(pp.Char(",").suppress() + "no_normalize")).set_name('blend_weights').set_debug(debug_blend)
blend = pp.Group(lparen + pp.Group(blend_terms) + rparen
+ pp.Literal(".blend").suppress()
+ lparen + pp.Group(blend_weights) + rparen).set_name('blend')
blend.set_debug(debug_blend)
def make_blend(x):
prompts = x[0][0]
weights = x[0][1]
normalize = True
if weights[-1] == 'no_normalize':
normalize = False
weights = weights[:-1]
return Blend(prompts=prompts, weights=weights, normalize_weights=normalize)
blend.set_parse_action(make_blend)
conjunction_terms = blend_terms.copy().set_name('conjunction_terms')
conjunction_weights = blend_weights.copy().set_name('conjunction_weights')
conjunction_with_parens_and_quotes = pp.Group(lparen + pp.Group(conjunction_terms) + rparen
+ pp.Literal(".and").suppress()
+ lparen + pp.Optional(pp.Group(conjunction_weights)) + rparen).set_name('conjunction')
def make_conjunction(x):
parts_raw = x[0][0]
weights = x[0][1] if len(x[0])>1 else [1.0]*len(parts_raw)
parts = [part for part in parts_raw]
return Conjunction(parts, weights)
conjunction_with_parens_and_quotes.set_parse_action(make_conjunction)
implicit_conjunction = pp.OneOrMore(blend | prompt).set_name('implicit_conjunction')
implicit_conjunction.set_parse_action(lambda x: Conjunction(x)) implicit_conjunction.set_parse_action(lambda x: Conjunction(x))
conjunction = conjunction_with_parens_and_quotes | implicit_conjunction conjunction = (explicit_conjunction | implicit_conjunction)
conjunction.set_debug(False)
# top-level is a conjunction of one or more blends or prompts
return conjunction, prompt return conjunction, prompt
def split_weighted_subprompts(text, skip_normalize=False)->list: def split_weighted_subprompts(text, skip_normalize=False)->list:
""" """
Legacy blend parsing. Legacy blend parsing.

View File

@ -34,6 +34,7 @@ def build_opt(post_data, seed, gfpgan_model_exists):
setattr(opt, 'facetool_strength', float(post_data['facetool_strength']) if gfpgan_model_exists else 0) setattr(opt, 'facetool_strength', float(post_data['facetool_strength']) if gfpgan_model_exists else 0)
setattr(opt, 'upscale', [int(post_data['upscale_level']), float(post_data['upscale_strength'])] if post_data['upscale_level'] != '' else None) setattr(opt, 'upscale', [int(post_data['upscale_level']), float(post_data['upscale_strength'])] if post_data['upscale_level'] != '' else None)
setattr(opt, 'progress_images', 'progress_images' in post_data) setattr(opt, 'progress_images', 'progress_images' in post_data)
setattr(opt, 'progress_latents', 'progress_latents' in post_data)
setattr(opt, 'seed', None if int(post_data['seed']) == -1 else int(post_data['seed'])) setattr(opt, 'seed', None if int(post_data['seed']) == -1 else int(post_data['seed']))
setattr(opt, 'threshold', float(post_data['threshold'])) setattr(opt, 'threshold', float(post_data['threshold']))
setattr(opt, 'perlin', float(post_data['perlin'])) setattr(opt, 'perlin', float(post_data['perlin']))
@ -227,8 +228,13 @@ class DreamServer(BaseHTTPRequestHandler):
# since rendering images is moderately expensive, only render every 5th image # since rendering images is moderately expensive, only render every 5th image
# and don't bother with the last one, since it'll render anyway # and don't bother with the last one, since it'll render anyway
nonlocal step_index nonlocal step_index
if opt.progress_images and step % 5 == 0 and step < opt.steps - 1:
image = self.model.sample_to_image(sample) wants_progress_latents = opt.progress_latents
wants_progress_image = opt.progress_image and step % 5 == 0
if (wants_progress_image | wants_progress_latents) and step < opt.steps - 1:
image = self.model.sample_to_image(sample) if wants_progress_image \
else self.model.sample_to_lowres_estimated_image(sample)
step_index_padded = str(step_index).rjust(len(str(opt.steps)), '0') step_index_padded = str(step_index).rjust(len(str(opt.steps)), '0')
name = f'{prefix}.{opt.seed}.{step_index_padded}.png' name = f'{prefix}.{opt.seed}.{step_index_padded}.png'
metadata = f'{opt.prompt} -S{opt.seed} [intermediate]' metadata = f'{opt.prompt} -S{opt.seed} [intermediate]'

View File

@ -39,6 +39,7 @@ class DreamBase():
model: str = None # The model to use (currently unused) model: str = None # The model to use (currently unused)
embeddings = None # The embeddings to use (currently unused) embeddings = None # The embeddings to use (currently unused)
progress_images: bool = False progress_images: bool = False
progress_latents: bool = False
# GFPGAN # GFPGAN
enable_gfpgan: bool enable_gfpgan: bool
@ -94,6 +95,7 @@ class DreamBase():
self.seamless = 'seamless' in j self.seamless = 'seamless' in j
self.hires_fix = 'hires_fix' in j self.hires_fix = 'hires_fix' in j
self.progress_images = 'progress_images' in j self.progress_images = 'progress_images' in j
self.progress_latents = 'progress_latents' in j
# GFPGAN # GFPGAN
self.enable_gfpgan = 'enable_gfpgan' in j and bool(j.get('enable_gfpgan')) self.enable_gfpgan = 'enable_gfpgan' in j and bool(j.get('enable_gfpgan'))

View File

@ -28,8 +28,8 @@ class PromptParserTestCase(unittest.TestCase):
self.assertEqual(make_weighted_conjunction([('', 1)]), parse_prompt('')) self.assertEqual(make_weighted_conjunction([('', 1)]), parse_prompt(''))
def test_basic(self): def test_basic(self):
self.assertEqual(make_weighted_conjunction([('fire flames', 1)]), parse_prompt("fire (flames)"))
self.assertEqual(make_weighted_conjunction([("fire flames", 1)]), parse_prompt("fire flames")) self.assertEqual(make_weighted_conjunction([("fire flames", 1)]), parse_prompt("fire flames"))
self.assertEqual(make_weighted_conjunction([('fire flames', 1)]), parse_prompt("fire (flames)"))
self.assertEqual(make_weighted_conjunction([("fire, flames", 1)]), parse_prompt("fire, flames")) self.assertEqual(make_weighted_conjunction([("fire, flames", 1)]), parse_prompt("fire, flames"))
self.assertEqual(make_weighted_conjunction([("fire, flames , fire", 1)]), parse_prompt("fire, flames , fire")) self.assertEqual(make_weighted_conjunction([("fire, flames , fire", 1)]), parse_prompt("fire, flames , fire"))
self.assertEqual(make_weighted_conjunction([("cat hot-dog eating", 1)]), parse_prompt("cat hot-dog eating")) self.assertEqual(make_weighted_conjunction([("cat hot-dog eating", 1)]), parse_prompt("cat hot-dog eating"))
@ -37,14 +37,25 @@ class PromptParserTestCase(unittest.TestCase):
def test_attention(self): def test_attention(self):
self.assertEqual(make_weighted_conjunction([('flames', 0.5)]), parse_prompt("(flames)0.5")) self.assertEqual(make_weighted_conjunction([('flames', 0.5)]), parse_prompt("(flames)0.5"))
self.assertEqual(make_weighted_conjunction([('flames', 0.5)]), parse_prompt("(flames).attend(0.5)"))
self.assertEqual(make_weighted_conjunction([('flames', 0.5)]), parse_prompt("flames.attend(0.5)"))
self.assertEqual(make_weighted_conjunction([('flames', 0.5)]), parse_prompt("\"flames\".attend(0.5)"))
self.assertEqual(make_weighted_conjunction([('fire flames', 0.5)]), parse_prompt("(fire flames)0.5")) self.assertEqual(make_weighted_conjunction([('fire flames', 0.5)]), parse_prompt("(fire flames)0.5"))
self.assertEqual(make_weighted_conjunction([('fire flames', 0.5)]), parse_prompt("(fire flames).attend(0.5)"))
self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("(flames)+")) self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("(flames)+"))
self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("flames+")) self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("flames+"))
self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("\"flames\"+")) self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("\"flames\"+"))
self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("flames.attend(+)"))
self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("(flames).attend(+)"))
self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("\"flames\".attend(+)"))
self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("(flames)-")) self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("(flames)-"))
self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("flames-")) self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("flames-"))
self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("\"flames\"-")) self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("\"flames\"-"))
self.assertEqual(make_weighted_conjunction([('fire', 1), ('flames', 0.5)]), parse_prompt("fire (flames)0.5")) self.assertEqual(make_weighted_conjunction([('fire', 1), ('flames', 0.5)]), parse_prompt("fire (flames)0.5"))
self.assertEqual(make_weighted_conjunction([('fire', 1), ('flames', 0.5)]), parse_prompt("fire flames.attend(0.5)"))
self.assertEqual(make_weighted_conjunction([('fire', 1), ('flames', 0.5)]), parse_prompt("fire (flames).attend(0.5)"))
self.assertEqual(make_weighted_conjunction([('fire', 1), ('flames', 0.5)]), parse_prompt("fire \"flames\".attend(0.5)"))
self.assertEqual(make_weighted_conjunction([('flames', pow(1.1, 2))]), parse_prompt("(flames)++")) self.assertEqual(make_weighted_conjunction([('flames', pow(1.1, 2))]), parse_prompt("(flames)++"))
self.assertEqual(make_weighted_conjunction([('flames', pow(0.9, 2))]), parse_prompt("(flames)--")) self.assertEqual(make_weighted_conjunction([('flames', pow(0.9, 2))]), parse_prompt("(flames)--"))
self.assertEqual(make_weighted_conjunction([('flowers', pow(0.9, 3)), ('flames', pow(1.1, 3))]), parse_prompt("(flowers)--- flames+++")) self.assertEqual(make_weighted_conjunction([('flowers', pow(0.9, 3)), ('flames', pow(1.1, 3))]), parse_prompt("(flowers)--- flames+++"))
@ -102,20 +113,17 @@ class PromptParserTestCase(unittest.TestCase):
assert_if_prompt_string_not_untouched('a test prompt') assert_if_prompt_string_not_untouched('a test prompt')
assert_if_prompt_string_not_untouched('a badly formed +test prompt') assert_if_prompt_string_not_untouched('a badly formed +test prompt')
with self.assertRaises(pyparsing.ParseException): assert_if_prompt_string_not_untouched('a badly (formed test prompt')
parse_prompt('a badly (formed test prompt')
#with self.assertRaises(pyparsing.ParseException): #with self.assertRaises(pyparsing.ParseException):
with self.assertRaises(pyparsing.ParseException): assert_if_prompt_string_not_untouched('a badly (formed +test prompt')
parse_prompt('a badly (formed +test prompt')
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a badly formed +test prompt',1)])]) , parse_prompt('a badly (formed +test )prompt')) self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a badly formed +test prompt',1)])]) , parse_prompt('a badly (formed +test )prompt'))
with self.assertRaises(pyparsing.ParseException): self.assertEqual(Conjunction([FlattenedPrompt([Fragment('(((a badly formed +test prompt',1)])]) , parse_prompt('(((a badly (formed +test )prompt'))
parse_prompt('(((a badly (formed +test )prompt')
with self.assertRaises(pyparsing.ParseException): self.assertEqual(Conjunction([FlattenedPrompt([Fragment('(a ba dly f ormed +test prompt',1)])]) , parse_prompt('(a (ba)dly (f)ormed +test prompt'))
parse_prompt('(a (ba)dly (f)ormed +test prompt') self.assertEqual(Conjunction([FlattenedPrompt([Fragment('(a ba dly f ormed +test +prompt',1)])]) , parse_prompt('(a (ba)dly (f)ormed +test +prompt'))
with self.assertRaises(pyparsing.ParseException): self.assertEqual(Conjunction([Blend([FlattenedPrompt([Fragment('((a badly (formed +test', 1)])], [1.0])]),
parse_prompt('(a (ba)dly (f)ormed +test +prompt') parse_prompt('("((a badly (formed +test ").blend(1.0)'))
with self.assertRaises(pyparsing.ParseException):
parse_prompt('("((a badly (formed +test ").blend(1.0)')
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('hamburger bun', 1)])]), self.assertEqual(Conjunction([FlattenedPrompt([Fragment('hamburger bun', 1)])]),
parse_prompt("hamburger ((bun))")) parse_prompt("hamburger ((bun))"))
@ -128,6 +136,26 @@ class PromptParserTestCase(unittest.TestCase):
def test_blend(self): def test_blend(self):
self.assertEqual(Conjunction(
[Blend([FlattenedPrompt([('mountain', 1.0)]), FlattenedPrompt([('man', 1.0)])], [1.0, 1.0])]),
parse_prompt("(\"mountain\", \"man\").blend()")
)
self.assertEqual(Conjunction(
[Blend([FlattenedPrompt([('mountain', 1.0)]), FlattenedPrompt([('man', 1.0)])], [1.0, 1.0])]),
parse_prompt("(mountain, man).blend()")
)
self.assertEqual(Conjunction(
[Blend([FlattenedPrompt([('mountain', 1.0)]), FlattenedPrompt([('man', 1.0)])], [1.0, 1.0])]),
parse_prompt("((mountain), (man)).blend()")
)
self.assertEqual(Conjunction(
[Blend([FlattenedPrompt([('mountain', 1.0)]), FlattenedPrompt([('tall man', 1.0)])], [1.0, 1.0])]),
parse_prompt("((mountain), (tall man)).blend()")
)
with self.assertRaises(PromptParser.ParsingException):
print(parse_prompt("((mountain), \"cat.swap(dog)\").blend()"))
self.assertEqual(Conjunction( self.assertEqual(Conjunction(
[Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('fire flames', 1.0)])], [0.7, 0.3])]), [Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('fire flames', 1.0)])], [0.7, 0.3])]),
parse_prompt("(\"fire\", \"fire flames\").blend(0.7, 0.3)") parse_prompt("(\"fire\", \"fire flames\").blend(0.7, 0.3)")
@ -166,10 +194,20 @@ class PromptParserTestCase(unittest.TestCase):
) )
self.assertEqual( self.assertEqual(
Conjunction([Blend([FlattenedPrompt([('mountain, man, hairy', 1)]), Conjunction([Blend([FlattenedPrompt([('mountain , man , hairy', 1)]),
FlattenedPrompt([('face, teeth,', 1), ('eyes', 0.9*0.9)])], weights=[1.0,-1.0])]), FlattenedPrompt([('face , teeth ,', 1), ('eyes', 0.9*0.9)])], weights=[1.0,-1.0], normalize_weights=True)]),
parse_prompt('("mountain, man, hairy", "face, teeth, eyes--").blend(1,-1)') parse_prompt('("mountain, man, hairy", "face, teeth, eyes--").blend(1,-1)')
) )
self.assertEqual(
Conjunction([Blend([FlattenedPrompt([('mountain , man , hairy', 1)]),
FlattenedPrompt([('face , teeth ,', 1), ('eyes', 0.9 * 0.9)])], weights=[1.0, -1.0], normalize_weights=False)]),
parse_prompt('("mountain, man, hairy", "face, teeth, eyes--").blend(1,-1,no_normalize)')
)
with self.assertRaises(PromptParser.ParsingException):
parse_prompt("(\"fire\", \"fire flames\").blend(0.7, 0.3, 0.1)")
with self.assertRaises(PromptParser.ParsingException):
parse_prompt("(\"fire\", \"fire flames\").blend(0.7)")
def test_nested(self): def test_nested(self):
@ -182,6 +220,9 @@ class PromptParserTestCase(unittest.TestCase):
def test_cross_attention_control(self): def test_cross_attention_control(self):
self.assertEqual(Conjunction([FlattenedPrompt([CrossAttentionControlSubstitute([Fragment('sun')], [Fragment('moon')])])]),
parse_prompt("sun.swap(moon)"))
self.assertEqual(Conjunction([ self.assertEqual(Conjunction([
FlattenedPrompt([Fragment('a', 1), FlattenedPrompt([Fragment('a', 1),
CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)]), CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)]),
@ -259,6 +300,12 @@ class PromptParserTestCase(unittest.TestCase):
Fragment(',', 1), Fragment('fire', 2.0)])]) Fragment(',', 1), Fragment('fire', 2.0)])])
self.assertEqual(flames_to_trees_fire, parse_prompt('"(fire (flames)0.5)0.5".swap("(trees)0.7 houses"), (fire)2.0')) self.assertEqual(flames_to_trees_fire, parse_prompt('"(fire (flames)0.5)0.5".swap("(trees)0.7 houses"), (fire)2.0'))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1),
CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]),
Fragment('eating a', 1),
CrossAttentionControlSubstitute([Fragment('hotdog',1)], [Fragment('hotdog', pow(1.1,4))])
])]),
parse_prompt("a cat.swap(dog) eating a hotdog.swap(hotdog++++)"))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1), self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1),
CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]), CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]),
Fragment('eating a', 1), Fragment('eating a', 1),
@ -343,31 +390,31 @@ class PromptParserTestCase(unittest.TestCase):
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain', 1.1), ('\(man\)', 1.1*1.1)]),parse_prompt('hairy (mountain (\(man\))+)+')) self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain', 1.1), ('\(man\)', 1.1*1.1)]),parse_prompt('hairy (mountain (\(man\))+)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('\(man\)', 1.1*1.1), ('mountain', 1.1)]),parse_prompt('hairy ((\(man\))1.1 "mountain")+')) self.assertEqual(make_weighted_conjunction([('hairy', 1), ('\(man\)', 1.1*1.1), ('mountain', 1.1)]),parse_prompt('hairy ((\(man\))1.1 "mountain")+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain', 1.1), ('\(man\)', 1.1*1.1)]),parse_prompt('hairy ("mountain" (\(man\))1.1 )+')) self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain', 1.1), ('\(man\)', 1.1*1.1)]),parse_prompt('hairy ("mountain" (\(man\))1.1 )+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, man', 1.1)]),parse_prompt('hairy ("mountain, man")+')) self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , man', 1.1)]),parse_prompt('hairy ("mountain, man")+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, man with a', 1.1), ('beard', 1.1*1.1)]), parse_prompt('hairy ("mountain, man" with a beard+)+')) self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , man with a', 1.1), ('beard', 1.1*1.1)]), parse_prompt('hairy ("mountain, man" with a beard+)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, man with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, man" with a (beard)2.0)+')) self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , man with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, man" with a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man\" with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\"man\\"" with a (beard)2.0)+')) self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , \"man\" with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\"man\\"" with a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, m\"an\" with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, m\\"an\\"" with a (beard)2.0)+')) self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , m\"an\" with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, m\\"an\\"" with a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man (with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" \(with a (beard)2.0)+')) self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , \"man (with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" \(with a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man w(ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" w\(ith a (beard)2.0)+')) self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , \"man w(ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" w\(ith a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man with( a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" with\( a (beard)2.0)+')) self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , \"man with( a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" with\( a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man )with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" \)with a (beard)2.0)+')) self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , \"man )with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" \)with a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" w\)ith a (beard)2.0)+')) self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , \"man w)ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" w\)ith a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man with) a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" with\) a (beard)2.0)+')) self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , \"man with) a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" with\) a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mou)ntain, \"man (wit(h a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mou\)ntain, \\\"man\" \(wit\(h a (beard)2.0)+')) self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mou)ntain , \"man (wit(h a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mou\)ntain, \\\"man\" \(wit\(h a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hai(ry', 1), ('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hai\(ry ("mountain, \\\"man\" w\)ith a (beard)2.0)+')) self.assertEqual(make_weighted_conjunction([('hai(ry', 1), ('mountain , \"man w)ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hai\(ry ("mountain, \\\"man\" w\)ith a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hairy((', 1), ('mountain, \"man with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy\(\( ("mountain, \\\"man\" with a (beard)2.0)+')) self.assertEqual(make_weighted_conjunction([('hairy((', 1), ('mountain , \"man with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy\(\( ("mountain, \\\"man\" with a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('mountain, \"man (with a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" \(with a (beard)2.0)+ hairy')) self.assertEqual(make_weighted_conjunction([('mountain , \"man (with a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" \(with a (beard)2.0)+ hairy'))
self.assertEqual(make_weighted_conjunction([('mountain, \"man w(ith a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" w\(ith a (beard)2.0)+hairy')) self.assertEqual(make_weighted_conjunction([('mountain , \"man w(ith a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" w\(ith a (beard)2.0)+hairy'))
self.assertEqual(make_weighted_conjunction([('mountain, \"man with( a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" with\( a (beard)2.0)+ hairy')) self.assertEqual(make_weighted_conjunction([('mountain , \"man with( a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" with\( a (beard)2.0)+ hairy'))
self.assertEqual(make_weighted_conjunction([('mountain, \"man )with a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" \)with a (beard)2.0)+ hairy')) self.assertEqual(make_weighted_conjunction([('mountain , \"man )with a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" \)with a (beard)2.0)+ hairy'))
self.assertEqual(make_weighted_conjunction([('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" w\)ith a (beard)2.0)+ hairy')) self.assertEqual(make_weighted_conjunction([('mountain , \"man w)ith a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" w\)ith a (beard)2.0)+ hairy'))
self.assertEqual(make_weighted_conjunction([('mountain, \"man with) a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt(' ("mountain, \\\"man\" with\) a (beard)2.0)+ hairy')) self.assertEqual(make_weighted_conjunction([('mountain , \"man with) a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt(' ("mountain, \\\"man\" with\) a (beard)2.0)+ hairy'))
self.assertEqual(make_weighted_conjunction([('mou)ntain, \"man (wit(h a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mou\)ntain, \\\"man\" \(wit\(h a (beard)2.0)+ hairy')) self.assertEqual(make_weighted_conjunction([('mou)ntain , \"man (wit(h a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mou\)ntain, \\\"man\" \(wit\(h a (beard)2.0)+ hairy'))
self.assertEqual(make_weighted_conjunction([('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0), ('hai(ry', 1)]), parse_prompt('("mountain, \\\"man\" w\)ith a (beard)2.0)+ hai\(ry ')) self.assertEqual(make_weighted_conjunction([('mountain , \"man w)ith a', 1.1), ('beard', 1.1*2.0), ('hai(ry', 1)]), parse_prompt('("mountain, \\\"man\" w\)ith a (beard)2.0)+ hai\(ry '))
self.assertEqual(make_weighted_conjunction([('mountain, \"man with a', 1.1), ('beard', 1.1*2.0), ('hairy((', 1)]), parse_prompt('("mountain, \\\"man\" with a (beard)2.0)+ hairy\(\( ')) self.assertEqual(make_weighted_conjunction([('mountain , \"man with a', 1.1), ('beard', 1.1*2.0), ('hairy((', 1)]), parse_prompt('("mountain, \\\"man\" with a (beard)2.0)+ hairy\(\( '))
def test_cross_attention_escaping(self): def test_cross_attention_escaping(self):
@ -433,6 +480,15 @@ class PromptParserTestCase(unittest.TestCase):
def test_single(self): def test_single(self):
self.assertEqual(Conjunction([FlattenedPrompt([("mountain man", 1.0)]),
FlattenedPrompt([("a person with a hat", 1.0),
("riding a", 1.1*1.1),
CrossAttentionControlSubstitute(
[Fragment("bicycle", pow(1.1,2))],
[Fragment("skateboard", pow(1.1,2))])
])
], weights=[0.5, 0.5]),
parse_prompt("(\"mountain man\", \"a person with a hat (riding a bicycle.swap(skateboard))++\").and(0.5, 0.5)"))
pass pass