diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx index c1124477e2..4f1bd39b8c 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx @@ -1,8 +1,9 @@ -import { CoreMetadata } from 'features/nodes/types/types'; +import { CoreMetadata, LoRAMetadataItem } from 'features/nodes/types/types'; import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters'; import { memo, useCallback } from 'react'; -import ImageMetadataItem from './ImageMetadataItem'; import { useTranslation } from 'react-i18next'; +import { isValidLoRAModel } from '../../../parameters/types/parameterSchemas'; +import ImageMetadataItem from './ImageMetadataItem'; type Props = { metadata?: CoreMetadata; @@ -24,6 +25,7 @@ const ImageMetadataActions = (props: Props) => { recallWidth, recallHeight, recallStrength, + recallLoRA, } = useRecallParameters(); const handleRecallPositivePrompt = useCallback(() => { @@ -66,6 +68,13 @@ const ImageMetadataActions = (props: Props) => { recallStrength(metadata?.strength); }, [metadata?.strength, recallStrength]); + const handleRecallLoRA = useCallback( + (lora: LoRAMetadataItem) => { + recallLoRA(lora); + }, + [recallLoRA] + ); + if (!metadata || Object.keys(metadata).length === 0) { return null; } @@ -130,20 +139,6 @@ const ImageMetadataActions = (props: Props) => { onClick={handleRecallHeight} /> )} - {/* {metadata.threshold !== undefined && ( - dispatch(setThreshold(Number(metadata.threshold)))} - /> - )} - {metadata.perlin !== undefined && ( - dispatch(setPerlin(Number(metadata.perlin)))} - /> - )} */} {metadata.scheduler && ( { onClick={handleRecallCfgScale} /> )} - {/* {metadata.variations && metadata.variations.length > 0 && ( - handleRecallLoRA(lora)} + /> + ); + } + })} ); }; diff --git a/invokeai/frontend/web/src/features/lora/store/loraSlice.ts b/invokeai/frontend/web/src/features/lora/store/loraSlice.ts index 10a1671933..bbe019b1c3 100644 --- a/invokeai/frontend/web/src/features/lora/store/loraSlice.ts +++ b/invokeai/frontend/web/src/features/lora/store/loraSlice.ts @@ -27,6 +27,13 @@ export const loraSlice = createSlice({ const { model_name, id, base_model } = action.payload; state.loras[id] = { id, model_name, base_model, ...defaultLoRAConfig }; }, + loraRecalled: ( + state, + action: PayloadAction + ) => { + const { model_name, id, base_model, weight } = action.payload; + state.loras[id] = { id, model_name, base_model, weight }; + }, loraRemoved: (state, action: PayloadAction) => { const id = action.payload; delete state.loras[id]; @@ -62,6 +69,7 @@ export const { loraWeightChanged, loraWeightReset, lorasCleared, + loraRecalled, } = loraSlice.actions; export default loraSlice.reducer; diff --git a/invokeai/frontend/web/src/features/nodes/types/types.ts b/invokeai/frontend/web/src/features/nodes/types/types.ts index 70bc71f3fc..bb5351243b 100644 --- a/invokeai/frontend/web/src/features/nodes/types/types.ts +++ b/invokeai/frontend/web/src/features/nodes/types/types.ts @@ -1065,6 +1065,13 @@ export const isInvocationFieldSchema = ( export type InvocationEdgeExtra = { type: 'default' | 'collapsed' }; +const zLoRAMetadataItem = z.object({ + lora: zLoRAModelField.deepPartial(), + weight: z.number(), +}); + +export type LoRAMetadataItem = z.infer; + export const zCoreMetadata = z .object({ app_version: z.string().nullish(), @@ -1084,14 +1091,7 @@ export const zCoreMetadata = z .union([zMainModel.deepPartial(), zOnnxModel.deepPartial()]) .nullish(), controlnets: z.array(zControlField.deepPartial()).nullish(), - loras: z - .array( - z.object({ - lora: zLoRAModelField.deepPartial(), - weight: z.number(), - }) - ) - .nullish(), + loras: z.array(zLoRAMetadataItem).nullish(), vae: zVaeModelField.nullish(), strength: z.number().nullish(), init_image: z.string().nullish(), diff --git a/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts b/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts index 203ff2cb1b..bc850df0d0 100644 --- a/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts +++ b/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts @@ -1,6 +1,8 @@ +import { createSelector } from '@reduxjs/toolkit'; import { useAppToaster } from 'app/components/Toaster'; -import { useAppDispatch } from 'app/store/storeHooks'; -import { CoreMetadata } from 'features/nodes/types/types'; +import { stateSelector } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { CoreMetadata, LoRAMetadataItem } from 'features/nodes/types/types'; import { refinerModelChanged, setNegativeStylePromptSDXL, @@ -15,6 +17,11 @@ import { import { useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { ImageDTO } from 'services/api/types'; +import { + loraModelsAdapter, + useGetLoRAModelsQuery, +} from '../../../services/api/endpoints/models'; +import { loraRecalled } from '../../lora/store/loraSlice'; import { initialImageSelected, modelSelected } from '../store/actions'; import { setCfgScale, @@ -30,6 +37,7 @@ import { import { isValidCfgScale, isValidHeight, + isValidLoRAModel, isValidMainModel, isValidNegativePrompt, isValidPositivePrompt, @@ -46,10 +54,16 @@ import { isValidWidth, } from '../types/parameterSchemas'; +const selector = createSelector(stateSelector, ({ generation }) => { + const { model } = generation; + return { model }; +}); + export const useRecallParameters = () => { const dispatch = useAppDispatch(); const toaster = useAppToaster(); const { t } = useTranslation(); + const { model } = useAppSelector(selector); const parameterSetToast = useCallback(() => { toaster({ @@ -60,14 +74,18 @@ export const useRecallParameters = () => { }); }, [t, toaster]); - const parameterNotSetToast = useCallback(() => { - toaster({ - title: t('toast.parameterNotSet'), - status: 'warning', - duration: 2500, - isClosable: true, - }); - }, [t, toaster]); + const parameterNotSetToast = useCallback( + (description?: string) => { + toaster({ + title: t('toast.parameterNotSet'), + description, + status: 'warning', + duration: 2500, + isClosable: true, + }); + }, + [t, toaster] + ); const allParameterSetToast = useCallback(() => { toaster({ @@ -78,14 +96,18 @@ export const useRecallParameters = () => { }); }, [t, toaster]); - const allParameterNotSetToast = useCallback(() => { - toaster({ - title: t('toast.parametersNotSet'), - status: 'warning', - duration: 2500, - isClosable: true, - }); - }, [t, toaster]); + const allParameterNotSetToast = useCallback( + (description?: string) => { + toaster({ + title: t('toast.parametersNotSet'), + status: 'warning', + description, + duration: 2500, + isClosable: true, + }); + }, + [t, toaster] + ); /** * Recall both prompts with toast @@ -307,6 +329,67 @@ export const useRecallParameters = () => { [dispatch, parameterSetToast, parameterNotSetToast] ); + /** + * Recall LoRA with toast + */ + + const { loras } = useGetLoRAModelsQuery(undefined, { + selectFromResult: (result) => ({ + loras: result.data + ? loraModelsAdapter.getSelectors().selectAll(result.data) + : [], + }), + }); + + const prepareLoRAMetadataItem = useCallback( + (loraMetadataItem: LoRAMetadataItem) => { + if (!isValidLoRAModel(loraMetadataItem.lora)) { + return { lora: null, error: 'Invalid LoRA model' }; + } + + const { base_model, model_name } = loraMetadataItem.lora; + + const matchingLoRA = loras.find( + (l) => l.base_model === base_model && l.model_name === model_name + ); + + if (!matchingLoRA) { + return { lora: null, error: 'LoRA model is not installed' }; + } + + const isCompatibleBaseModel = + matchingLoRA?.base_model === model?.base_model; + + if (!isCompatibleBaseModel) { + return { + lora: null, + error: 'LoRA incompatible with currently-selected model', + }; + } + + return { lora: matchingLoRA, error: null }; + }, + [loras, model?.base_model] + ); + + const recallLoRA = useCallback( + (loraMetadataItem: LoRAMetadataItem) => { + const result = prepareLoRAMetadataItem(loraMetadataItem); + + if (!result.lora) { + parameterNotSetToast(result.error); + return; + } + + dispatch( + loraRecalled({ ...result.lora, weight: loraMetadataItem.weight }) + ); + + parameterSetToast(); + }, + [prepareLoRAMetadataItem, dispatch, parameterSetToast, parameterNotSetToast] + ); + /* * Sets image as initial image with toast */ @@ -344,6 +427,7 @@ export const useRecallParameters = () => { refiner_positive_aesthetic_score, refiner_negative_aesthetic_score, refiner_start, + loras, } = metadata; if (isValidCfgScale(cfg_scale)) { @@ -425,9 +509,21 @@ export const useRecallParameters = () => { dispatch(setRefinerStart(refiner_start)); } + loras?.forEach((lora) => { + const result = prepareLoRAMetadataItem(lora); + if (result.lora) { + dispatch(loraRecalled({ ...result.lora, weight: lora.weight })); + } + }); + allParameterSetToast(); }, - [allParameterNotSetToast, allParameterSetToast, dispatch] + [ + allParameterNotSetToast, + allParameterSetToast, + dispatch, + prepareLoRAMetadataItem, + ] ); return { @@ -444,6 +540,7 @@ export const useRecallParameters = () => { recallWidth, recallHeight, recallStrength, + recallLoRA, recallAllParameters, sendToImageToImage, }; diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index 9be8bd13f6..9db7762344 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -128,7 +128,7 @@ export const mainModelsAdapter = createEntityAdapter({ const onnxModelsAdapter = createEntityAdapter({ sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), }); -const loraModelsAdapter = createEntityAdapter({ +export const loraModelsAdapter = createEntityAdapter({ sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), }); export const controlNetModelsAdapter = diff --git a/pyproject.toml b/pyproject.toml index 63ac3b7c12..2ea5455c3d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -198,6 +198,13 @@ output = "coverage/index.xml" max-line-length = 120 ignore = ["E203", "E266", "E501", "W503"] select = ["B", "C", "E", "F", "W", "T4"] +exclude = [ + ".git", + "__pycache__", + "build", + "dist", + "invokeai/frontend/web/node_modules/" +] [tool.black] line-length = 120