mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(ui): handle concat when recalling prompts
This required some minor reworking of of the logic to recall multiple items. I split this into a utility function that includes some special handling for concat. Closes #6478
This commit is contained in:
parent
89a764a359
commit
64523c4b1b
@ -1,4 +1,7 @@
|
||||
import { getStore } from 'app/store/nanostores/store';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { objectKeys } from 'common/util/objectKeys';
|
||||
import { shouldConcatPromptsChanged } from 'features/controlLayers/store/controlLayersSlice';
|
||||
import type { Layer } from 'features/controlLayers/store/types';
|
||||
import type { LoRA } from 'features/lora/store/loraSlice';
|
||||
import type {
|
||||
@ -16,6 +19,7 @@ import { validators } from 'features/metadata/util/validators';
|
||||
import type { ModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { t } from 'i18next';
|
||||
import { size } from 'lodash-es';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
import { parsers } from './parsers';
|
||||
@ -376,54 +380,25 @@ export const handlers = {
|
||||
}),
|
||||
} as const;
|
||||
|
||||
type ParsedValue = Awaited<ReturnType<(typeof handlers)[keyof typeof handlers]['parse']>>;
|
||||
type RecallResults = Partial<Record<keyof typeof handlers, ParsedValue>>;
|
||||
|
||||
export const parseAndRecallPrompts = async (metadata: unknown) => {
|
||||
const results = await Promise.allSettled([
|
||||
handlers.positivePrompt.parse(metadata).then((positivePrompt) => {
|
||||
if (!handlers.positivePrompt.recall) {
|
||||
return;
|
||||
}
|
||||
handlers.positivePrompt?.recall(positivePrompt);
|
||||
}),
|
||||
handlers.negativePrompt.parse(metadata).then((negativePrompt) => {
|
||||
if (!handlers.negativePrompt.recall) {
|
||||
return;
|
||||
}
|
||||
handlers.negativePrompt?.recall(negativePrompt);
|
||||
}),
|
||||
handlers.sdxlPositiveStylePrompt.parse(metadata).then((sdxlPositiveStylePrompt) => {
|
||||
if (!handlers.sdxlPositiveStylePrompt.recall) {
|
||||
return;
|
||||
}
|
||||
handlers.sdxlPositiveStylePrompt?.recall(sdxlPositiveStylePrompt);
|
||||
}),
|
||||
handlers.sdxlNegativeStylePrompt.parse(metadata).then((sdxlNegativeStylePrompt) => {
|
||||
if (!handlers.sdxlNegativeStylePrompt.recall) {
|
||||
return;
|
||||
}
|
||||
handlers.sdxlNegativeStylePrompt?.recall(sdxlNegativeStylePrompt);
|
||||
}),
|
||||
]);
|
||||
if (results.some((result) => result.status === 'fulfilled')) {
|
||||
const keysToRecall: (keyof typeof handlers)[] = [
|
||||
'positivePrompt',
|
||||
'negativePrompt',
|
||||
'sdxlPositiveStylePrompt',
|
||||
'sdxlNegativeStylePrompt',
|
||||
];
|
||||
const recalled = await recallKeys(keysToRecall, metadata);
|
||||
if (size(recalled) > 0) {
|
||||
parameterSetToast(t('metadata.allPrompts'));
|
||||
}
|
||||
};
|
||||
|
||||
export const parseAndRecallImageDimensions = async (metadata: unknown) => {
|
||||
const results = await Promise.allSettled([
|
||||
handlers.width.parse(metadata).then((width) => {
|
||||
if (!handlers.width.recall) {
|
||||
return;
|
||||
}
|
||||
handlers.width?.recall(width);
|
||||
}),
|
||||
handlers.height.parse(metadata).then((height) => {
|
||||
if (!handlers.height.recall) {
|
||||
return;
|
||||
}
|
||||
handlers.height?.recall(height);
|
||||
}),
|
||||
]);
|
||||
if (results.some((result) => result.status === 'fulfilled')) {
|
||||
const recalled = recallKeys(['width', 'height'], metadata);
|
||||
if (size(recalled) > 0) {
|
||||
parameterSetToast(t('metadata.imageDimensions'));
|
||||
}
|
||||
};
|
||||
@ -438,28 +413,20 @@ export const parseAndRecallAllMetadata = async (
|
||||
toControlLayers: boolean,
|
||||
skip: (keyof typeof handlers)[] = []
|
||||
) => {
|
||||
const skipKeys = skip ?? [];
|
||||
const skipKeys = deepClone(skip);
|
||||
if (toControlLayers) {
|
||||
skipKeys.push(...TO_CONTROL_LAYERS_SKIP_KEYS);
|
||||
} else {
|
||||
skipKeys.push(...NOT_TO_CONTROL_LAYERS_SKIP_KEYS);
|
||||
}
|
||||
const results = await Promise.allSettled(
|
||||
objectKeys(handlers)
|
||||
.filter((key) => !skipKeys.includes(key))
|
||||
.map((key) => {
|
||||
const { parse, recall } = handlers[key];
|
||||
return parse(metadata).then((value) => {
|
||||
if (!recall) {
|
||||
return;
|
||||
}
|
||||
/* @ts-expect-error The return type of parse and the input type of recall are guaranteed to be compatible. */
|
||||
recall(value);
|
||||
});
|
||||
})
|
||||
);
|
||||
|
||||
if (results.some((result) => result.status === 'fulfilled')) {
|
||||
// We may need to take some further action depending on what was recalled. For example, we need to disable SDXL prompt
|
||||
// concat if the negative or positive style prompt was set. Because the recalling is all async, we need to collect all
|
||||
// results
|
||||
const keysToRecall = objectKeys(handlers).filter((key) => !skipKeys.includes(key));
|
||||
const recalled = await recallKeys(keysToRecall, metadata);
|
||||
|
||||
if (size(recalled) > 0) {
|
||||
toast({
|
||||
id: 'PARAMETER_SET',
|
||||
title: t('toast.parametersSet'),
|
||||
@ -473,3 +440,43 @@ export const parseAndRecallAllMetadata = async (
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Recalls a set of keys from metadata.
|
||||
* Includes special handling for some metadata where recalling may have side effects. For example, recalling a "style"
|
||||
* prompt that is different from the "positive" or "negative" prompt should disable prompt concatenation.
|
||||
* @param keysToRecall An array of keys to recall.
|
||||
* @param metadata The metadata to recall from
|
||||
* @returns A promise that resolves to an object containing the recalled values.
|
||||
*/
|
||||
const recallKeys = async (keysToRecall: (keyof typeof handlers)[], metadata: unknown): Promise<RecallResults> => {
|
||||
const { dispatch } = getStore();
|
||||
const recalled: RecallResults = {};
|
||||
for (const key of keysToRecall) {
|
||||
const { parse, recall } = handlers[key];
|
||||
if (!recall) {
|
||||
continue;
|
||||
}
|
||||
try {
|
||||
const value = await parse(metadata);
|
||||
/* @ts-expect-error The return type of parse and the input type of recall are guaranteed to be compatible. */
|
||||
await recall(value);
|
||||
recalled[key] = value;
|
||||
} catch {
|
||||
// no-op
|
||||
}
|
||||
}
|
||||
|
||||
if (
|
||||
(recalled['sdxlPositiveStylePrompt'] && recalled['sdxlPositiveStylePrompt'] !== recalled['positivePrompt']) ||
|
||||
(recalled['sdxlNegativeStylePrompt'] && recalled['sdxlNegativeStylePrompt'] !== recalled['negativePrompt'])
|
||||
) {
|
||||
// If we set the negative style prompt or positive style prompt, we should disable prompt concat
|
||||
dispatch(shouldConcatPromptsChanged(false));
|
||||
} else {
|
||||
// Otherwise, we should enable prompt concat
|
||||
dispatch(shouldConcatPromptsChanged(true));
|
||||
}
|
||||
|
||||
return recalled;
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user