From 94f16b1c69784ca417ae8613164e6c09fa523b54 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 18 Sep 2023 15:54:40 +1000 Subject: [PATCH] feat(ui): provide feedback when recalling invalid lora --- .../ImageMetadataActions.tsx | 4 +- .../web/src/features/nodes/types/types.ts | 6 +- .../parameters/hooks/useRecallParameters.ts | 105 +++++++++++++----- 3 files changed, 83 insertions(+), 32 deletions(-) diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx index cf58334667..4f1bd39b8c 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx @@ -1,4 +1,4 @@ -import { CoreMetadata, LoRAMetadataType } from 'features/nodes/types/types'; +import { CoreMetadata, LoRAMetadataItem } from 'features/nodes/types/types'; import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; @@ -69,7 +69,7 @@ const ImageMetadataActions = (props: Props) => { }, [metadata?.strength, recallStrength]); const handleRecallLoRA = useCallback( - (lora: LoRAMetadataType) => { + (lora: LoRAMetadataItem) => { recallLoRA(lora); }, [recallLoRA] diff --git a/invokeai/frontend/web/src/features/nodes/types/types.ts b/invokeai/frontend/web/src/features/nodes/types/types.ts index 6a03c07d57..c92a41391c 100644 --- a/invokeai/frontend/web/src/features/nodes/types/types.ts +++ b/invokeai/frontend/web/src/features/nodes/types/types.ts @@ -1057,12 +1057,12 @@ export const isInvocationFieldSchema = ( export type InvocationEdgeExtra = { type: 'default' | 'collapsed' }; -const zLoRAObject = z.object({ +const zLoRAMetadataItem = z.object({ lora: zLoRAModelField.deepPartial(), weight: z.number(), }); -export type LoRAMetadataType = z.infer; +export type LoRAMetadataItem = z.infer; export const zCoreMetadata = z .object({ @@ -1083,7 +1083,7 @@ export const zCoreMetadata = z .union([zMainModel.deepPartial(), zOnnxModel.deepPartial()]) .nullish(), controlnets: z.array(zControlField.deepPartial()).nullish(), - loras: z.array(zLoRAObject).nullish(), + loras: z.array(zLoRAMetadataItem).nullish(), vae: zVaeModelField.nullish(), strength: z.number().nullish(), init_image: z.string().nullish(), diff --git a/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts b/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts index fc33200bd3..38764ed856 100644 --- a/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts +++ b/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts @@ -1,6 +1,8 @@ +import { createSelector } from '@reduxjs/toolkit'; import { useAppToaster } from 'app/components/Toaster'; -import { useAppDispatch } from 'app/store/storeHooks'; -import { CoreMetadata, LoRAMetadataType } from 'features/nodes/types/types'; +import { stateSelector } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { CoreMetadata, LoRAMetadataItem } from 'features/nodes/types/types'; import { refinerModelChanged, setNegativeStylePromptSDXL, @@ -12,7 +14,7 @@ import { setRefinerStart, setRefinerSteps, } from 'features/sdxl/store/sdxlSlice'; -import { useCallback } from 'react'; +import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { ImageDTO } from 'services/api/types'; import { @@ -52,10 +54,16 @@ import { isValidWidth, } from '../types/parameterSchemas'; +const selector = createSelector(stateSelector, ({ generation }) => { + const { model } = generation; + return { model }; +}); + export const useRecallParameters = () => { const dispatch = useAppDispatch(); const toaster = useAppToaster(); const { t } = useTranslation(); + const { model } = useAppSelector(selector); const parameterSetToast = useCallback(() => { toaster({ @@ -66,14 +74,18 @@ export const useRecallParameters = () => { }); }, [t, toaster]); - const parameterNotSetToast = useCallback(() => { - toaster({ - title: t('toast.parameterNotSet'), - status: 'warning', - duration: 2500, - isClosable: true, - }); - }, [t, toaster]); + const parameterNotSetToast = useCallback( + (description?: string) => { + toaster({ + title: t('toast.parameterNotSet'), + description, + status: 'warning', + duration: 2500, + isClosable: true, + }); + }, + [t, toaster] + ); const allParameterSetToast = useCallback(() => { toaster({ @@ -84,14 +96,18 @@ export const useRecallParameters = () => { }); }, [t, toaster]); - const allParameterNotSetToast = useCallback(() => { - toaster({ - title: t('toast.parametersNotSet'), - status: 'warning', - duration: 2500, - isClosable: true, - }); - }, [t, toaster]); + const allParameterNotSetToast = useCallback( + (description?: string) => { + toaster({ + title: t('toast.parametersNotSet'), + status: 'warning', + description, + duration: 2500, + isClosable: true, + }); + }, + [t, toaster] + ); /** * Recall both prompts with toast @@ -325,11 +341,10 @@ export const useRecallParameters = () => { }), }); - const recallLoRA = useCallback( - (loraMetadataItem: LoRAMetadataType) => { + const prepareLoRAMetadataItem = useCallback( + (loraMetadataItem: LoRAMetadataItem) => { if (!isValidLoRAModel(loraMetadataItem.lora)) { - parameterNotSetToast(); - return; + return { lora: null, error: 'Invalid LoRA model' }; } const { base_model, model_name } = loraMetadataItem.lora; @@ -339,17 +354,40 @@ export const useRecallParameters = () => { ); if (!matchingLoRA) { - parameterNotSetToast(); + return { lora: null, error: 'LoRA model is not installed' }; + } + + const isCompatibleBaseModel = + matchingLoRA?.base_model === model?.base_model; + + if (!isCompatibleBaseModel) { + return { + lora: null, + error: 'LoRA incompatible with currently-selected model', + }; + } + + return { lora: matchingLoRA, error: null }; + }, + [loras, model?.base_model] + ); + + const recallLoRA = useCallback( + (loraMetadataItem: LoRAMetadataItem) => { + const result = prepareLoRAMetadataItem(loraMetadataItem); + + if (!result.lora) { + parameterNotSetToast(result.error); return; } dispatch( - loraRecalled({ ...matchingLoRA, weight: loraMetadataItem.weight }) + loraRecalled({ ...result.lora, weight: loraMetadataItem.weight }) ); parameterSetToast(); }, - [loras, dispatch, parameterSetToast, parameterNotSetToast] + [prepareLoRAMetadataItem, dispatch, parameterSetToast, parameterNotSetToast] ); /* @@ -389,6 +427,7 @@ export const useRecallParameters = () => { refiner_positive_aesthetic_score, refiner_negative_aesthetic_score, refiner_start, + loras, } = metadata; if (isValidCfgScale(cfg_scale)) { @@ -470,9 +509,21 @@ export const useRecallParameters = () => { dispatch(setRefinerStart(refiner_start)); } + loras?.forEach((lora) => { + const result = prepareLoRAMetadataItem(lora); + if (result.lora) { + dispatch(loraRecalled({ ...result.lora, weight: lora.weight })); + } + }); + allParameterSetToast(); }, - [allParameterNotSetToast, allParameterSetToast, dispatch] + [ + allParameterNotSetToast, + allParameterSetToast, + dispatch, + prepareLoRAMetadataItem, + ] ); return {