feat(ui): refactor parameter recall

- use zod to validate parameters before recalling
- update recall params hook to handle all validation and UI feedback
This commit is contained in:
psychedelicious 2023-05-31 21:05:31 +10:00
parent 062b2cf46f
commit 6571e4c2fd
15 changed files with 667 additions and 764 deletions

View File

@ -101,7 +101,8 @@
"serialize-error": "^11.0.0",
"socket.io-client": "^4.6.0",
"use-image": "^1.1.0",
"uuid": "^9.0.0"
"uuid": "^9.0.0",
"zod": "^3.21.4"
},
"peerDependencies": {
"@chakra-ui/cli": "^2.4.0",

View File

@ -568,6 +568,8 @@
"canvasMerged": "Canvas Merged",
"sentToImageToImage": "Sent To Image To Image",
"sentToUnifiedCanvas": "Sent to Unified Canvas",
"parameterSet": "Parameter set",
"parameterNotSet": "Parameter not set",
"parametersSet": "Parameters Set",
"parametersNotSet": "Parameters Not Set",
"parametersNotSetDesc": "No metadata found for this image.",

View File

@ -21,25 +21,11 @@ export const SCHEDULERS = [
export type Scheduler = (typeof SCHEDULERS)[number];
export const isScheduler = (x: string): x is Scheduler =>
SCHEDULERS.includes(x as Scheduler);
// Valid image widths
export const WIDTHS: Array<number> = Array.from(Array(64)).map(
(_x, i) => (i + 1) * 64
);
// Valid image heights
export const HEIGHTS: Array<number> = Array.from(Array(64)).map(
(_x, i) => (i + 1) * 64
);
// Valid upscaling levels
export const UPSCALING_LEVELS: Array<{ key: string; value: number }> = [
{ key: '2x', value: 2 },
{ key: '4x', value: 4 },
];
export const NUMPY_RAND_MIN = 0;
export const NUMPY_RAND_MAX = 2147483647;

View File

@ -1,316 +1,82 @@
/**
* Types for images, the things they are made of, and the things
* they make up.
*
* Generated images are txt2img and img2img images. They may have
* had additional postprocessing done on them when they were first
* generated.
*
* Postprocessed images are images which were not generated here
* but only postprocessed by the app. They only get postprocessing
* metadata and have a different image type, e.g. 'esrgan' or
* 'gfpgan'.
*/
import { SelectedImage } from 'features/parameters/store/actions';
import { InvokeTabName } from 'features/ui/store/tabMap';
import { IRect } from 'konva/lib/types';
import { ImageResponseMetadata, ResourceOrigin } from 'services/api';
import { O } from 'ts-toolbelt';
/**
* TODO:
* Once an image has been generated, if it is postprocessed again,
* additional postprocessing steps are added to its postprocessing
* array.
*
* TODO: Better documentation of types.
*/
// These are old types from the model management UI
export type PromptItem = {
prompt: string;
weight: number;
};
// export type ModelStatus = 'active' | 'cached' | 'not loaded';
// TECHDEBT: We need to retain compatibility with plain prompt strings and the structure Prompt type
export type Prompt = Array<PromptItem> | string;
export type SeedWeightPair = {
seed: number;
weight: number;
};
export type SeedWeights = Array<SeedWeightPair>;
// All generated images contain these metadata.
export type CommonGeneratedImageMetadata = {
postprocessing: null | Array<ESRGANMetadata | FacetoolMetadata>;
sampler:
| 'ddim'
| 'ddpm'
| 'deis'
| 'lms'
| 'pndm'
| 'heun'
| 'heun_k'
| 'euler'
| 'euler_k'
| 'euler_a'
| 'kdpm_2'
| 'kdpm_2_a'
| 'dpmpp_2s'
| 'dpmpp_2m'
| 'dpmpp_2m_k'
| 'unipc';
prompt: Prompt;
seed: number;
variations: SeedWeights;
steps: number;
cfg_scale: number;
width: number;
height: number;
seamless: boolean;
hires_fix: boolean;
extra: null | Record<string, never>; // Pending development of RFC #266
};
// txt2img and img2img images have some unique attributes.
export type Txt2ImgMetadata = CommonGeneratedImageMetadata & {
type: 'txt2img';
};
export type Img2ImgMetadata = CommonGeneratedImageMetadata & {
type: 'img2img';
orig_hash: string;
strength: number;
fit: boolean;
init_image_path: string;
mask_image_path?: string;
};
// Superset of generated image metadata types.
export type GeneratedImageMetadata = Txt2ImgMetadata | Img2ImgMetadata;
// All post processed images contain these metadata.
export type CommonPostProcessedImageMetadata = {
orig_path: string;
orig_hash: string;
};
// esrgan and gfpgan images have some unique attributes.
export type ESRGANMetadata = CommonPostProcessedImageMetadata & {
type: 'esrgan';
scale: 2 | 4;
strength: number;
denoise_str: number;
};
export type FacetoolMetadata = CommonPostProcessedImageMetadata & {
type: 'gfpgan' | 'codeformer';
strength: number;
fidelity?: number;
};
// Superset of all postprocessed image metadata types..
export type PostProcessedImageMetadata = ESRGANMetadata | FacetoolMetadata;
// Metadata includes the system config and image metadata.
// export type Metadata = SystemGenerationMetadata & {
// image: GeneratedImageMetadata | PostProcessedImageMetadata;
// export type Model = {
// status: ModelStatus;
// description: string;
// weights: string;
// config?: string;
// vae?: string;
// width?: number;
// height?: number;
// default?: boolean;
// format?: string;
// };
/**
* ResultImage
*/
// export ty`pe Image = {
// export type DiffusersModel = {
// status: ModelStatus;
// description: string;
// repo_id?: string;
// path?: string;
// vae?: {
// repo_id?: string;
// path?: string;
// };
// format?: string;
// default?: boolean;
// };
// export type ModelList = Record<string, Model & DiffusersModel>;
// export type FoundModel = {
// name: string;
// type: image_origin;
// url: string;
// thumbnail: string;
// metadata: ImageResponseMetadata;
// location: string;
// };
// export const isInvokeAIImage = (obj: Image | SelectedImage): obj is Image => {
// if ('url' in obj && 'thumbnail' in obj) {
// return true;
// }
// return false;
// export type InvokeModelConfigProps = {
// name: string | undefined;
// description: string | undefined;
// config: string | undefined;
// weights: string | undefined;
// vae: string | undefined;
// width: number | undefined;
// height: number | undefined;
// default: boolean | undefined;
// format: string | undefined;
// };
/**
* Types related to the system status.
*/
// // This represents the processing status of the backend.
// export type SystemStatus = {
// isProcessing: boolean;
// currentStep: number;
// totalSteps: number;
// currentIteration: number;
// totalIterations: number;
// currentStatus: string;
// currentStatusHasSteps: boolean;
// hasError: boolean;
// export type InvokeDiffusersModelConfigProps = {
// name: string | undefined;
// description: string | undefined;
// repo_id: string | undefined;
// path: string | undefined;
// default: boolean | undefined;
// format: string | undefined;
// vae: {
// repo_id: string | undefined;
// path: string | undefined;
// };
// };
// export type SystemGenerationMetadata = {
// model: string;
// model_weights?: string;
// model_id?: string;
// model_hash: string;
// app_id: string;
// app_version: string;
// export type InvokeModelConversionProps = {
// model_name: string;
// save_location: string;
// custom_location: string | null;
// };
// export type SystemConfig = SystemGenerationMetadata & {
// model_list: ModelList;
// infill_methods: string[];
// export type InvokeModelMergingProps = {
// models_to_merge: string[];
// alpha: number;
// interp: 'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference';
// force: boolean;
// merged_model_name: string;
// model_merge_save_path: string | null;
// };
export type ModelStatus = 'active' | 'cached' | 'not loaded';
export type Model = {
status: ModelStatus;
description: string;
weights: string;
config?: string;
vae?: string;
width?: number;
height?: number;
default?: boolean;
format?: string;
};
export type DiffusersModel = {
status: ModelStatus;
description: string;
repo_id?: string;
path?: string;
vae?: {
repo_id?: string;
path?: string;
};
format?: string;
default?: boolean;
};
export type ModelList = Record<string, Model & DiffusersModel>;
export type FoundModel = {
name: string;
location: string;
};
export type InvokeModelConfigProps = {
name: string | undefined;
description: string | undefined;
config: string | undefined;
weights: string | undefined;
vae: string | undefined;
width: number | undefined;
height: number | undefined;
default: boolean | undefined;
format: string | undefined;
};
export type InvokeDiffusersModelConfigProps = {
name: string | undefined;
description: string | undefined;
repo_id: string | undefined;
path: string | undefined;
default: boolean | undefined;
format: string | undefined;
vae: {
repo_id: string | undefined;
path: string | undefined;
};
};
export type InvokeModelConversionProps = {
model_name: string;
save_location: string;
custom_location: string | null;
};
export type InvokeModelMergingProps = {
models_to_merge: string[];
alpha: number;
interp: 'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference';
force: boolean;
merged_model_name: string;
model_merge_save_path: string | null;
};
/**
* These types type data received from the server via socketio.
*/
export type ModelChangeResponse = {
model_name: string;
model_list: ModelList;
};
export type ModelConvertedResponse = {
converted_model_name: string;
model_list: ModelList;
};
export type ModelsMergedResponse = {
merged_models: string[];
merged_model_name: string;
model_list: ModelList;
};
export type ModelAddedResponse = {
new_model_name: string;
model_list: ModelList;
update: boolean;
};
export type ModelDeletedResponse = {
deleted_model_name: string;
model_list: ModelList;
};
export type FoundModelResponse = {
search_folder: string;
found_models: FoundModel[];
};
// export type SystemStatusResponse = SystemStatus;
// export type SystemConfigResponse = SystemConfig;
export type ImageResultResponse = Omit<Image, 'uuid'> & {
boundingBox?: IRect;
generationMode: InvokeTabName;
};
export type ImageUploadResponse = {
// image: Omit<Image, 'uuid' | 'metadata' | 'category'>;
url: string;
mtime: number;
width: number;
height: number;
thumbnail: string;
// bbox: [number, number, number, number];
};
export type ErrorResponse = {
message: string;
additionalData?: string;
};
export type ImageUrlResponse = {
url: string;
};
export type UploadOutpaintingMergeImagePayload = {
dataURL: string;
name: string;
};
/**
* A disable-able application feature
*/
@ -322,7 +88,8 @@ export type AppFeature =
| 'githubLink'
| 'discordLink'
| 'bugLink'
| 'localization';
| 'localization'
| 'consoleLogging';
/**
* A disable-able Stable Diffusion feature

View File

@ -1,119 +0,0 @@
/**
* PARTIAL ZOD IMPLEMENTATION
*
* doesn't work well bc like most validators, zod is not built to skip invalid values.
* it mostly works but just seems clearer and simpler to manually parse for now.
*
* in the future it would be really nice if we could use zod for some things:
* - zodios (axios + zod): https://github.com/ecyrbe/zodios
* - openapi to zodios: https://github.com/astahmer/openapi-zod-client
*/
// import { z } from 'zod';
// const zMetadataStringField = z.string();
// export type MetadataStringField = z.infer<typeof zMetadataStringField>;
// const zMetadataIntegerField = z.number().int();
// export type MetadataIntegerField = z.infer<typeof zMetadataIntegerField>;
// const zMetadataFloatField = z.number();
// export type MetadataFloatField = z.infer<typeof zMetadataFloatField>;
// const zMetadataBooleanField = z.boolean();
// export type MetadataBooleanField = z.infer<typeof zMetadataBooleanField>;
// const zMetadataImageField = z.object({
// image_type: z.union([
// z.literal('results'),
// z.literal('uploads'),
// z.literal('intermediates'),
// ]),
// image_name: z.string().min(1),
// });
// export type MetadataImageField = z.infer<typeof zMetadataImageField>;
// const zMetadataLatentsField = z.object({
// latents_name: z.string().min(1),
// });
// export type MetadataLatentsField = z.infer<typeof zMetadataLatentsField>;
// /**
// * zod Schema for any node field. Use a `transform()` to manually parse, skipping invalid values.
// */
// const zAnyMetadataField = z.any().transform((val, ctx) => {
// // Grab the field name from the path
// const fieldName = String(ctx.path[ctx.path.length - 1]);
// // `id` and `type` must be strings if they exist
// if (['id', 'type'].includes(fieldName)) {
// const reservedStringPropertyResult = zMetadataStringField.safeParse(val);
// if (reservedStringPropertyResult.success) {
// return reservedStringPropertyResult.data;
// }
// return;
// }
// // Parse the rest of the fields, only returning the data if the parsing is successful
// const stringFieldResult = zMetadataStringField.safeParse(val);
// if (stringFieldResult.success) {
// return stringFieldResult.data;
// }
// const integerFieldResult = zMetadataIntegerField.safeParse(val);
// if (integerFieldResult.success) {
// return integerFieldResult.data;
// }
// const floatFieldResult = zMetadataFloatField.safeParse(val);
// if (floatFieldResult.success) {
// return floatFieldResult.data;
// }
// const booleanFieldResult = zMetadataBooleanField.safeParse(val);
// if (booleanFieldResult.success) {
// return booleanFieldResult.data;
// }
// const imageFieldResult = zMetadataImageField.safeParse(val);
// if (imageFieldResult.success) {
// return imageFieldResult.data;
// }
// const latentsFieldResult = zMetadataImageField.safeParse(val);
// if (latentsFieldResult.success) {
// return latentsFieldResult.data;
// }
// });
// /**
// * The node metadata schema.
// */
// const zNodeMetadata = z.object({
// session_id: z.string().min(1).optional(),
// node: z.record(z.string().min(1), zAnyMetadataField).optional(),
// });
// export type NodeMetadata = z.infer<typeof zNodeMetadata>;
// const zMetadata = z.object({
// invokeai: zNodeMetadata.optional(),
// 'sd-metadata': z.record(z.string().min(1), z.any()).optional(),
// });
// export type Metadata = z.infer<typeof zMetadata>;
// export const parseMetadata = (
// metadata: Record<string, any>
// ): Metadata | undefined => {
// const result = zMetadata.safeParse(metadata);
// if (!result.success) {
// console.log(result.error.issues);
// return;
// }
// return result.data;
// };
export default {};

View File

@ -49,7 +49,7 @@ import { useCallback } from 'react';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import { useGetUrl } from 'common/util/getUrl';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { useParameters } from 'features/parameters/hooks/useParameters';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { initialImageSelected } from 'features/parameters/store/actions';
import {
requestedImageDeletion,
@ -58,7 +58,6 @@ import {
} from '../store/actions';
import FaceRestoreSettings from 'features/parameters/components/Parameters/FaceRestore/FaceRestoreSettings';
import UpscaleSettings from 'features/parameters/components/Parameters/Upscale/UpscaleSettings';
import { allParametersSet } from 'features/parameters/store/generationSlice';
import DeleteImageButton from './ImageActionButtons/DeleteImageButton';
import { useAppToaster } from 'app/components/Toaster';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
@ -165,7 +164,8 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
const toaster = useAppToaster();
const { t } = useTranslation();
const { recallPrompt, recallSeed, recallAllParameters } = useParameters();
const { recallBothPrompts, recallSeed, recallAllParameters } =
useRecallParameters();
// const handleCopyImage = useCallback(async () => {
// if (!image?.url) {
@ -250,11 +250,11 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
useHotkeys('s', handleUseSeed, [image]);
const handleUsePrompt = useCallback(() => {
recallPrompt(
recallBothPrompts(
image?.metadata?.positive_conditioning,
image?.metadata?.negative_conditioning
);
}, [image, recallPrompt]);
}, [image, recallBothPrompts]);
useHotkeys('p', handleUsePrompt, [image]);

View File

@ -30,7 +30,7 @@ import { lightboxSelector } from 'features/lightbox/store/lightboxSelectors';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { isEqual } from 'lodash-es';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { useParameters } from 'features/parameters/hooks/useParameters';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { initialImageSelected } from 'features/parameters/store/actions';
import {
requestedImageDeletion,
@ -114,8 +114,8 @@ const HoverableImage = memo((props: HoverableImageProps) => {
const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled;
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
const { recallSeed, recallPrompt, recallInitialImage, recallAllParameters } =
useParameters();
const { recallBothPrompts, recallSeed, recallAllParameters } =
useRecallParameters();
const handleMouseOver = () => setIsHovered(true);
const handleMouseOut = () => setIsHovered(false);
@ -154,11 +154,15 @@ const HoverableImage = memo((props: HoverableImageProps) => {
// Recall parameters handlers
const handleRecallPrompt = useCallback(() => {
recallPrompt(
recallBothPrompts(
image.metadata?.positive_conditioning,
image.metadata?.negative_conditioning
);
}, [image, recallPrompt]);
}, [
image.metadata?.negative_conditioning,
image.metadata?.positive_conditioning,
recallBothPrompts,
]);
const handleRecallSeed = useCallback(() => {
recallSeed(image.metadata?.seed);

View File

@ -31,6 +31,7 @@ import { IoArrowUndoCircleOutline } from 'react-icons/io5';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import { ImageDTO } from 'services/api';
import { Scheduler } from 'app/constants';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
type MetadataItemProps = {
isLink?: boolean;
@ -120,6 +121,21 @@ const memoEqualityCheck = (
*/
const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
const dispatch = useAppDispatch();
const {
recallBothPrompts,
recallPositivePrompt,
recallNegativePrompt,
recallSeed,
recallInitialImage,
recallCfgScale,
recallModel,
recallScheduler,
recallSteps,
recallWidth,
recallHeight,
recallStrength,
recallAllParameters,
} = useRecallParameters();
useHotkeys('esc', () => {
dispatch(setShouldShowImageDetails(false));
@ -166,52 +182,53 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
{metadata.type && (
<MetadataItem label="Invocation type" value={metadata.type} />
)}
{metadata.width && (
<MetadataItem
label="Width"
value={metadata.width}
onClick={() => dispatch(setWidth(Number(metadata.width)))}
/>
)}
{metadata.height && (
<MetadataItem
label="Height"
value={metadata.height}
onClick={() => dispatch(setHeight(Number(metadata.height)))}
/>
)}
{metadata.model && (
<MetadataItem label="Model" value={metadata.model} />
)}
{sessionId && <MetadataItem label="Session ID" value={sessionId} />}
{metadata.positive_conditioning && (
<MetadataItem
label="Prompt"
label="Positive Prompt"
labelPosition="top"
value={
typeof metadata.positive_conditioning === 'string'
? metadata.positive_conditioning
: promptToString(metadata.positive_conditioning)
value={metadata.positive_conditioning}
onClick={() =>
recallPositivePrompt(metadata.positive_conditioning)
}
onClick={() => setPositivePrompt(metadata.positive_conditioning!)}
/>
)}
{metadata.negative_conditioning && (
<MetadataItem
label="Prompt"
label="Negative Prompt"
labelPosition="top"
value={
typeof metadata.negative_conditioning === 'string'
? metadata.negative_conditioning
: promptToString(metadata.negative_conditioning)
value={metadata.negative_conditioning}
onClick={() =>
recallNegativePrompt(metadata.negative_conditioning)
}
onClick={() => setNegativePrompt(metadata.negative_conditioning!)}
/>
)}
{metadata.seed !== undefined && (
<MetadataItem
label="Seed"
value={metadata.seed}
onClick={() => dispatch(setSeed(Number(metadata.seed)))}
onClick={() => recallSeed(metadata.seed)}
/>
)}
{metadata.model !== undefined && (
<MetadataItem
label="Model"
value={metadata.model}
onClick={() => recallModel(metadata.model)}
/>
)}
{metadata.width && (
<MetadataItem
label="Width"
value={metadata.width}
onClick={() => recallWidth(metadata.width)}
/>
)}
{metadata.height && (
<MetadataItem
label="Height"
value={metadata.height}
onClick={() => recallHeight(metadata.height)}
/>
)}
{/* {metadata.threshold !== undefined && (
@ -232,23 +249,21 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
<MetadataItem
label="Scheduler"
value={metadata.scheduler}
onClick={() =>
dispatch(setScheduler(metadata.scheduler as Scheduler))
}
onClick={() => recallScheduler(metadata.scheduler)}
/>
)}
{metadata.steps && (
<MetadataItem
label="Steps"
value={metadata.steps}
onClick={() => dispatch(setSteps(Number(metadata.steps)))}
onClick={() => recallSteps(metadata.steps)}
/>
)}
{metadata.cfg_scale !== undefined && (
<MetadataItem
label="CFG scale"
value={metadata.cfg_scale}
onClick={() => dispatch(setCfgScale(Number(metadata.cfg_scale)))}
onClick={() => recallCfgScale(metadata.cfg_scale)}
/>
)}
{/* {metadata.variations && metadata.variations.length > 0 && (
@ -289,9 +304,7 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
<MetadataItem
label="Image to image strength"
value={metadata.strength}
onClick={() =>
dispatch(setImg2imgStrength(Number(metadata.strength)))
}
onClick={() => recallStrength(metadata.strength)}
/>
)}
{/* {metadata.fit && (

View File

@ -1,151 +0,0 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { isFinite, isString } from 'lodash-es';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import useSetBothPrompts from './usePrompt';
import { allParametersSet, setSeed } from '../store/generationSlice';
import { isImageField } from 'services/types/guards';
import { NUMPY_RAND_MAX } from 'app/constants';
import { initialImageSelected } from '../store/actions';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { useAppToaster } from 'app/components/Toaster';
import { ImageDTO } from 'services/api';
export const useParameters = () => {
const dispatch = useAppDispatch();
const toaster = useAppToaster();
const { t } = useTranslation();
const setBothPrompts = useSetBothPrompts();
/**
* Sets prompt with toast
*/
const recallPrompt = useCallback(
(prompt: unknown, negativePrompt?: unknown) => {
if (!isString(prompt) || !isString(negativePrompt)) {
toaster({
title: t('toast.promptNotSet'),
description: t('toast.promptNotSetDesc'),
status: 'warning',
duration: 2500,
isClosable: true,
});
return;
}
setBothPrompts(prompt, negativePrompt);
toaster({
title: t('toast.promptSet'),
status: 'info',
duration: 2500,
isClosable: true,
});
},
[t, toaster, setBothPrompts]
);
/**
* Sets seed with toast
*/
const recallSeed = useCallback(
(seed: unknown) => {
const s = Number(seed);
if (!isFinite(s) || (isFinite(s) && !(s >= 0 && s <= NUMPY_RAND_MAX))) {
toaster({
title: t('toast.seedNotSet'),
description: t('toast.seedNotSetDesc'),
status: 'warning',
duration: 2500,
isClosable: true,
});
return;
}
dispatch(setSeed(s));
toaster({
title: t('toast.seedSet'),
status: 'info',
duration: 2500,
isClosable: true,
});
},
[t, toaster, dispatch]
);
/**
* Sets initial image with toast
*/
const recallInitialImage = useCallback(
async (image: unknown) => {
if (!isImageField(image)) {
toaster({
title: t('toast.initialImageNotSet'),
description: t('toast.initialImageNotSetDesc'),
status: 'warning',
duration: 2500,
isClosable: true,
});
return;
}
dispatch(initialImageSelected(image.image_name));
toaster({
title: t('toast.initialImageSet'),
status: 'info',
duration: 2500,
isClosable: true,
});
},
[t, toaster, dispatch]
);
/**
* Sets image as initial image with toast
*/
const sendToImageToImage = useCallback(
(image: ImageDTO) => {
dispatch(initialImageSelected(image));
},
[dispatch]
);
const recallAllParameters = useCallback(
(image: ImageDTO | undefined) => {
const type = image?.metadata?.type;
// not sure what this list should be
if (['t2l', 'l2l', 'inpaint'].includes(String(type))) {
dispatch(allParametersSet(image));
if (image?.metadata?.type === 'l2l') {
dispatch(setActiveTab('img2img'));
} else if (image?.metadata?.type === 't2l') {
dispatch(setActiveTab('txt2img'));
}
toaster({
title: t('toast.parametersSet'),
status: 'success',
duration: 2500,
isClosable: true,
});
} else {
toaster({
title: t('toast.parametersNotSet'),
description: t('toast.parametersNotSetDesc'),
status: 'error',
duration: 2500,
isClosable: true,
});
}
},
[t, toaster, dispatch]
);
return {
recallPrompt,
recallSeed,
recallInitialImage,
sendToImageToImage,
recallAllParameters,
};
};

View File

@ -1,23 +0,0 @@
import { getPromptAndNegative } from 'common/util/getPromptAndNegative';
import * as InvokeAI from 'app/types/invokeai';
import promptToString from 'common/util/promptToString';
import { useAppDispatch } from 'app/store/storeHooks';
import { setNegativePrompt, setPositivePrompt } from '../store/generationSlice';
import { useCallback } from 'react';
// TECHDEBT: We have two metadata prompt formats and need to handle recalling either of them.
// This hook provides a function to do that.
const useSetBothPrompts = () => {
const dispatch = useAppDispatch();
return useCallback(
(inputPrompt: InvokeAI.Prompt, negativePrompt: InvokeAI.Prompt) => {
dispatch(setPositivePrompt(inputPrompt));
dispatch(setNegativePrompt(negativePrompt));
},
[dispatch]
);
};
export default useSetBothPrompts;

View File

@ -0,0 +1,348 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import {
modelSelected,
setCfgScale,
setHeight,
setImg2imgStrength,
setNegativePrompt,
setPositivePrompt,
setScheduler,
setSeed,
setSteps,
setWidth,
} from '../store/generationSlice';
import { isImageField } from 'services/types/guards';
import { initialImageSelected } from '../store/actions';
import { useAppToaster } from 'app/components/Toaster';
import { ImageDTO } from 'services/api';
import {
isValidCfgScale,
isValidHeight,
isValidModel,
isValidNegativePrompt,
isValidPositivePrompt,
isValidScheduler,
isValidSeed,
isValidSteps,
isValidStrength,
isValidWidth,
} from '../store/parameterZodSchemas';
export const useRecallParameters = () => {
const dispatch = useAppDispatch();
const toaster = useAppToaster();
const { t } = useTranslation();
const parameterSetToast = useCallback(() => {
toaster({
title: t('toast.parameterSet'),
status: 'info',
duration: 2500,
isClosable: true,
});
}, [t, toaster]);
const parameterNotSetToast = useCallback(() => {
toaster({
title: t('toast.parameterNotSet'),
status: 'warning',
duration: 2500,
isClosable: true,
});
}, [t, toaster]);
const allParameterSetToast = useCallback(() => {
toaster({
title: t('toast.parametersSet'),
status: 'info',
duration: 2500,
isClosable: true,
});
}, [t, toaster]);
const allParameterNotSetToast = useCallback(() => {
toaster({
title: t('toast.parametersNotSet'),
status: 'warning',
duration: 2500,
isClosable: true,
});
}, [t, toaster]);
/**
* Recall both prompts with toast
*/
const recallBothPrompts = useCallback(
(positivePrompt: unknown, negativePrompt: unknown) => {
if (
isValidPositivePrompt(positivePrompt) ||
isValidNegativePrompt(negativePrompt)
) {
if (isValidPositivePrompt(positivePrompt)) {
dispatch(setPositivePrompt(positivePrompt));
}
if (isValidNegativePrompt(negativePrompt)) {
dispatch(setNegativePrompt(negativePrompt));
}
parameterSetToast();
return;
}
parameterNotSetToast();
},
[dispatch, parameterSetToast, parameterNotSetToast]
);
/**
* Recall positive prompt with toast
*/
const recallPositivePrompt = useCallback(
(positivePrompt: unknown) => {
if (!isValidPositivePrompt(positivePrompt)) {
parameterNotSetToast();
return;
}
dispatch(setPositivePrompt(positivePrompt));
parameterSetToast();
},
[dispatch, parameterSetToast, parameterNotSetToast]
);
/**
* Recall negative prompt with toast
*/
const recallNegativePrompt = useCallback(
(negativePrompt: unknown) => {
if (!isValidNegativePrompt(negativePrompt)) {
parameterNotSetToast();
return;
}
dispatch(setNegativePrompt(negativePrompt));
parameterSetToast();
},
[dispatch, parameterSetToast, parameterNotSetToast]
);
/**
* Recall seed with toast
*/
const recallSeed = useCallback(
(seed: unknown) => {
if (!isValidSeed(seed)) {
parameterNotSetToast();
return;
}
dispatch(setSeed(seed));
parameterSetToast();
},
[dispatch, parameterSetToast, parameterNotSetToast]
);
/**
* Recall CFG scale with toast
*/
const recallCfgScale = useCallback(
(cfgScale: unknown) => {
if (!isValidCfgScale(cfgScale)) {
parameterNotSetToast();
return;
}
dispatch(setCfgScale(cfgScale));
parameterSetToast();
},
[dispatch, parameterSetToast, parameterNotSetToast]
);
/**
* Recall model with toast
*/
const recallModel = useCallback(
(model: unknown) => {
if (!isValidModel(model)) {
parameterNotSetToast();
return;
}
dispatch(modelSelected(model));
parameterSetToast();
},
[dispatch, parameterSetToast, parameterNotSetToast]
);
/**
* Recall scheduler with toast
*/
const recallScheduler = useCallback(
(scheduler: unknown) => {
if (!isValidScheduler(scheduler)) {
parameterNotSetToast();
return;
}
dispatch(setScheduler(scheduler));
parameterSetToast();
},
[dispatch, parameterSetToast, parameterNotSetToast]
);
/**
* Recall steps with toast
*/
const recallSteps = useCallback(
(steps: unknown) => {
if (!isValidSteps(steps)) {
parameterNotSetToast();
return;
}
dispatch(setSteps(steps));
parameterSetToast();
},
[dispatch, parameterSetToast, parameterNotSetToast]
);
/**
* Recall width with toast
*/
const recallWidth = useCallback(
(width: unknown) => {
if (!isValidWidth(width)) {
parameterNotSetToast();
return;
}
dispatch(setWidth(width));
parameterSetToast();
},
[dispatch, parameterSetToast, parameterNotSetToast]
);
/**
* Recall height with toast
*/
const recallHeight = useCallback(
(height: unknown) => {
if (!isValidHeight(height)) {
parameterNotSetToast();
return;
}
dispatch(setHeight(height));
parameterSetToast();
},
[dispatch, parameterSetToast, parameterNotSetToast]
);
/**
* Recall strength with toast
*/
const recallStrength = useCallback(
(strength: unknown) => {
if (!isValidStrength(strength)) {
parameterNotSetToast();
return;
}
dispatch(setImg2imgStrength(strength));
parameterSetToast();
},
[dispatch, parameterSetToast, parameterNotSetToast]
);
/**
* Sets initial image with toast
*/
const recallInitialImage = useCallback(
async (image: unknown) => {
if (!isImageField(image)) {
parameterNotSetToast();
return;
}
dispatch(initialImageSelected(image.image_name));
parameterSetToast();
},
[dispatch, parameterSetToast, parameterNotSetToast]
);
/**
* Sets image as initial image with toast
*/
const sendToImageToImage = useCallback(
(image: ImageDTO) => {
dispatch(initialImageSelected(image));
},
[dispatch]
);
const recallAllParameters = useCallback(
(image: ImageDTO | undefined) => {
if (!image || !image.metadata) {
allParameterNotSetToast();
return;
}
const {
cfg_scale,
height,
model,
positive_conditioning,
negative_conditioning,
scheduler,
seed,
steps,
width,
strength,
clip,
extra,
latents,
unet,
vae,
} = image.metadata;
if (isValidCfgScale(cfg_scale)) {
dispatch(setCfgScale(cfg_scale));
}
if (isValidModel(model)) {
dispatch(modelSelected(model));
}
if (isValidPositivePrompt(positive_conditioning)) {
dispatch(setPositivePrompt(positive_conditioning));
}
if (isValidNegativePrompt(negative_conditioning)) {
dispatch(setNegativePrompt(negative_conditioning));
}
if (isValidScheduler(scheduler)) {
dispatch(setScheduler(scheduler));
}
if (isValidSeed(seed)) {
dispatch(setSeed(seed));
}
if (isValidSteps(steps)) {
dispatch(setSteps(steps));
}
if (isValidWidth(width)) {
dispatch(setWidth(width));
}
if (isValidHeight(height)) {
dispatch(setHeight(height));
}
if (isValidStrength(strength)) {
dispatch(setImg2imgStrength(strength));
}
allParameterSetToast();
},
[allParameterNotSetToast, allParameterSetToast, dispatch]
);
return {
recallBothPrompts,
recallPositivePrompt,
recallNegativePrompt,
recallSeed,
recallInitialImage,
recallCfgScale,
recallModel,
recallScheduler,
recallSteps,
recallWidth,
recallHeight,
recallStrength,
recallAllParameters,
sendToImageToImage,
};
};

View File

@ -1,44 +1,53 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import * as InvokeAI from 'app/types/invokeai';
import promptToString from 'common/util/promptToString';
import { clamp, sortBy } from 'lodash-es';
import { setAllParametersReducer } from './setAllParametersReducer';
import { receivedModels } from 'services/thunks/model';
import { Scheduler } from 'app/constants';
import { ImageDTO } from 'services/api';
import { configChanged } from 'features/system/store/configSlice';
import {
CfgScaleParam,
HeightParam,
ModelParam,
NegativePromptParam,
PositivePromptParam,
SchedulerParam,
SeedParam,
StepsParam,
StrengthParam,
WidthParam,
} from './parameterZodSchemas';
export interface GenerationState {
cfgScale: number;
height: number;
img2imgStrength: number;
cfgScale: CfgScaleParam;
height: HeightParam;
img2imgStrength: StrengthParam;
infillMethod: string;
initialImage?: ImageDTO;
iterations: number;
perlin: number;
positivePrompt: string;
negativePrompt: string;
scheduler: Scheduler;
positivePrompt: PositivePromptParam;
negativePrompt: NegativePromptParam;
scheduler: SchedulerParam;
seamBlur: number;
seamSize: number;
seamSteps: number;
seamStrength: number;
seed: number;
seed: SeedParam;
seedWeights: string;
shouldFitToWidthHeight: boolean;
shouldGenerateVariations: boolean;
shouldRandomizeSeed: boolean;
shouldUseNoiseSettings: boolean;
steps: number;
steps: StepsParam;
threshold: number;
tileSize: number;
variationAmount: number;
width: number;
width: WidthParam;
shouldUseSymmetry: boolean;
horizontalSymmetrySteps: number;
verticalSymmetrySteps: number;
model: string;
model: ModelParam;
shouldUseSeamless: boolean;
seamlessXAxis: boolean;
seamlessYAxis: boolean;
@ -84,27 +93,11 @@ export const generationSlice = createSlice({
name: 'generation',
initialState,
reducers: {
setPositivePrompt: (
state,
action: PayloadAction<string | InvokeAI.Prompt>
) => {
const newPrompt = action.payload;
if (typeof newPrompt === 'string') {
state.positivePrompt = newPrompt;
} else {
state.positivePrompt = promptToString(newPrompt);
}
setPositivePrompt: (state, action: PayloadAction<string>) => {
state.positivePrompt = action.payload;
},
setNegativePrompt: (
state,
action: PayloadAction<string | InvokeAI.Prompt>
) => {
const newPrompt = action.payload;
if (typeof newPrompt === 'string') {
state.negativePrompt = newPrompt;
} else {
state.negativePrompt = promptToString(newPrompt);
}
setNegativePrompt: (state, action: PayloadAction<string>) => {
state.negativePrompt = action.payload;
},
setIterations: (state, action: PayloadAction<number>) => {
state.iterations = action.payload;
@ -175,7 +168,6 @@ export const generationSlice = createSlice({
state.shouldGenerateVariations = true;
state.variationAmount = 0;
},
allParametersSet: setAllParametersReducer,
resetParametersState: (state) => {
return {
...state,
@ -279,7 +271,6 @@ export const {
setSeamless,
setSeamlessXAxis,
setSeamlessYAxis,
allParametersSet,
} = generationSlice.actions;
export default generationSlice.reducer;

View File

@ -0,0 +1,156 @@
import { NUMPY_RAND_MAX, SCHEDULERS } from 'app/constants';
import { z } from 'zod';
/**
* These zod schemas should match the pydantic node schemas.
*
* Parameters only need schemas if we want to recall them from metadata.
*
* Each parameter needs:
* - a zod schema
* - a type alias, inferred from the zod schema
* - a combo validation/type guard function, which returns true if the value is valid
*/
/**
* Zod schema for positive prompt parameter
*/
export const zPositivePrompt = z.string();
/**
* Type alias for positive prompt parameter, inferred from its zod schema
*/
export type PositivePromptParam = z.infer<typeof zPositivePrompt>;
/**
* Validates/type-guards a value as a positive prompt parameter
*/
export const isValidPositivePrompt = (
val: unknown
): val is PositivePromptParam => zPositivePrompt.safeParse(val).success;
/**
* Zod schema for negative prompt parameter
*/
export const zNegativePrompt = z.string();
/**
* Type alias for negative prompt parameter, inferred from its zod schema
*/
export type NegativePromptParam = z.infer<typeof zNegativePrompt>;
/**
* Validates/type-guards a value as a negative prompt parameter
*/
export const isValidNegativePrompt = (
val: unknown
): val is NegativePromptParam => zNegativePrompt.safeParse(val).success;
/**
* Zod schema for steps parameter
*/
export const zSteps = z.number().int().min(1);
/**
* Type alias for steps parameter, inferred from its zod schema
*/
export type StepsParam = z.infer<typeof zSteps>;
/**
* Validates/type-guards a value as a steps parameter
*/
export const isValidSteps = (val: unknown): val is StepsParam =>
zSteps.safeParse(val).success;
/**
* Zod schema for CFG scale parameter
*/
export const zCfgScale = z.number().min(1);
/**
* Type alias for CFG scale parameter, inferred from its zod schema
*/
export type CfgScaleParam = z.infer<typeof zCfgScale>;
/**
* Validates/type-guards a value as a CFG scale parameter
*/
export const isValidCfgScale = (val: unknown): val is CfgScaleParam =>
zCfgScale.safeParse(val).success;
/**
* Zod schema for scheduler parameter
*/
export const zScheduler = z.enum(SCHEDULERS);
/**
* Type alias for scheduler parameter, inferred from its zod schema
*/
export type SchedulerParam = z.infer<typeof zScheduler>;
/**
* Validates/type-guards a value as a scheduler parameter
*/
export const isValidScheduler = (val: unknown): val is SchedulerParam =>
zScheduler.safeParse(val).success;
/**
* Zod schema for seed parameter
*/
export const zSeed = z.number().int().min(0).max(NUMPY_RAND_MAX);
/**
* Type alias for seed parameter, inferred from its zod schema
*/
export type SeedParam = z.infer<typeof zSeed>;
/**
* Validates/type-guards a value as a seed parameter
*/
export const isValidSeed = (val: unknown): val is SeedParam =>
zSeed.safeParse(val).success;
/**
* Zod schema for width parameter
*/
export const zWidth = z.number().multipleOf(8).min(64);
/**
* Type alias for width parameter, inferred from its zod schema
*/
export type WidthParam = z.infer<typeof zWidth>;
/**
* Validates/type-guards a value as a width parameter
*/
export const isValidWidth = (val: unknown): val is WidthParam =>
zWidth.safeParse(val).success;
/**
* Zod schema for height parameter
*/
export const zHeight = z.number().multipleOf(8).min(64);
/**
* Type alias for height parameter, inferred from its zod schema
*/
export type HeightParam = z.infer<typeof zHeight>;
/**
* Validates/type-guards a value as a height parameter
*/
export const isValidHeight = (val: unknown): val is HeightParam =>
zHeight.safeParse(val).success;
/**
* Zod schema for model parameter
* TODO: Make this a dynamically generated enum?
*/
export const zModel = z.string();
/**
* Type alias for model parameter, inferred from its zod schema
*/
export type ModelParam = z.infer<typeof zModel>;
/**
* Validates/type-guards a value as a model parameter
*/
export const isValidModel = (val: unknown): val is ModelParam =>
zModel.safeParse(val).success;
/**
* Zod schema for l2l strength parameter
*/
export const zStrength = z.number().min(0).max(1);
/**
* Type alias for l2l strength parameter, inferred from its zod schema
*/
export type StrengthParam = z.infer<typeof zStrength>;
/**
* Validates/type-guards a value as a l2l strength parameter
*/
export const isValidStrength = (val: unknown): val is StrengthParam =>
zStrength.safeParse(val).success;

View File

@ -1,77 +0,0 @@
import { Draft, PayloadAction } from '@reduxjs/toolkit';
import { GenerationState } from './generationSlice';
import { ImageDTO, ImageToImageInvocation } from 'services/api';
import { isScheduler } from 'app/constants';
export const setAllParametersReducer = (
state: Draft<GenerationState>,
action: PayloadAction<ImageDTO | undefined>
) => {
const metadata = action.payload?.metadata;
if (!metadata) {
return;
}
// not sure what this list should be
if (
metadata.type === 't2l' ||
metadata.type === 'l2l' ||
metadata.type === 'inpaint'
) {
const {
cfg_scale,
height,
model,
positive_conditioning,
negative_conditioning,
scheduler,
seed,
steps,
width,
} = metadata;
if (cfg_scale !== undefined) {
state.cfgScale = Number(cfg_scale);
}
if (height !== undefined) {
state.height = Number(height);
}
if (model !== undefined) {
state.model = String(model);
}
if (positive_conditioning !== undefined) {
state.positivePrompt = String(positive_conditioning);
}
if (negative_conditioning !== undefined) {
state.negativePrompt = String(negative_conditioning);
}
if (scheduler !== undefined) {
const schedulerString = String(scheduler);
if (isScheduler(schedulerString)) {
state.scheduler = schedulerString;
}
}
if (seed !== undefined) {
state.seed = Number(seed);
state.shouldRandomizeSeed = false;
}
if (steps !== undefined) {
state.steps = Number(steps);
}
if (width !== undefined) {
state.width = Number(width);
}
}
if (metadata.type === 'l2l') {
const { fit, image } = metadata as ImageToImageInvocation;
if (fit !== undefined) {
state.shouldFitToWidthHeight = Boolean(fit);
}
// if (image !== undefined) {
// state.initialImage = image;
// }
}
};

View File

@ -6877,6 +6877,11 @@ z-schema@~5.0.2:
optionalDependencies:
commander "^10.0.0"
zod@^3.21.4:
version "3.21.4"
resolved "https://registry.yarnpkg.com/zod/-/zod-3.21.4.tgz#10882231d992519f0a10b5dd58a38c9dabbb64db"
integrity sha512-m46AKbrzKVzOzs/DZgVnG5H55N1sv1M8qZU3A8RIKbs3mrACDNeIOeilDymVb2HdmP8uwshOCF4uJ8uM9rCqJw==
zustand@^4.3.1:
version "4.3.7"
resolved "https://registry.yarnpkg.com/zustand/-/zustand-4.3.7.tgz#501b1f0393a7f1d103332e45ab574be5747fedce"